[Numpy-svn] r3281 - trunk/numpy/lib

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Oct 6 21:59:35 CDT 2006


Author: oliphant
Date: 2006-10-06 21:59:33 -0500 (Fri, 06 Oct 2006)
New Revision: 3281

Modified:
   trunk/numpy/lib/shape_base.py
Log:
Fix kron to be N-dimensional.

Modified: trunk/numpy/lib/shape_base.py
===================================================================
--- trunk/numpy/lib/shape_base.py	2006-10-07 02:44:41 UTC (rev 3280)
+++ trunk/numpy/lib/shape_base.py	2006-10-07 02:59:33 UTC (rev 3281)
@@ -1,7 +1,7 @@
 __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
            'column_stack','row_stack', 'dstack','array_split','split','hsplit',
            'vsplit','dsplit','apply_over_axes','expand_dims',
-           'apply_along_axis', 'tile', 'kron']
+           'apply_along_axis', 'kron', 'tile']
 
 import numpy.core.numeric as _nx
 from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -542,27 +542,30 @@
 def kron(a,b):
     """kronecker product of a and b
 
-    Kronecker product of two matrices is block matrix
+    Kronecker product of two arrays is block array
     [[ a[ 0 ,0]*b, a[ 0 ,1]*b, ... , a[ 0 ,n-1]*b  ],
      [ ...                                   ...   ],
      [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b  ]]
     """
     wrapper = _getwrapper(a, b)
-    a = asanyarray(a)
-    b = asanyarray(b)
-    if not (len(a.shape) == len(b.shape) == 2):
-        raise ValueError("a and b must both be two dimensional")
+    b = asanyarray(b)    
+    a = array(a,copy=False,subok=True,ndmin=b.ndim)
+    as = a.shape
+    bs = b.shape
     if not a.flags.contiguous:
-        a = reshape(a, a.shape)
+        a = reshape(a, as)
     if not b.flags.contiguous:
-        b = reshape(b, b.shape)
+        b = reshape(b, bs)
     o = outer(a,b)
-    o=o.reshape(a.shape + b.shape)
-    result = concatenate(concatenate(o, axis=1), axis=1)
+    result = o.reshape(as + bs)
+    axis = a.ndim-1
+    for k in xrange(b.ndim):
+        result = concatenate(result, axis=axis)
     if wrapper is not None:
         result = wrapper(result)
     return result
 
+
 def tile(A, reps):
     """Repeat an array the number of times given in the integer tuple, reps.
 



More information about the Numpy-svn mailing list