[Numpy-svn] r3293 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Oct 9 02:47:10 CDT 2006


Author: oliphant
Date: 2006-10-09 02:47:06 -0500 (Mon, 09 Oct 2006)
New Revision: 3293

Modified:
   trunk/numpy/lib/shape_base.py
   trunk/numpy/lib/tests/test_shape_base.py
Log:
Fix kron for multiple dimensions.  kron is defined so that tile(b, s) is the same as kron(ones(s, b.dtype), b).

Modified: trunk/numpy/lib/shape_base.py
===================================================================
--- trunk/numpy/lib/shape_base.py	2006-10-08 13:16:13 UTC (rev 3292)
+++ trunk/numpy/lib/shape_base.py	2006-10-09 07:47:06 UTC (rev 3293)
@@ -1,7 +1,7 @@
 __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
            'column_stack','row_stack', 'dstack','array_split','split','hsplit',
            'vsplit','dsplit','apply_over_axes','expand_dims',
-           'apply_along_axis', 'kron', 'tile']
+           'apply_along_axis', 'kron', 'tile', 'get_array_wrap']
 
 import numpy.core.numeric as _nx
 from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -526,7 +526,7 @@
         raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
     return split(ary,indices_or_sections,2)
 
-def _getwrapper(*args):
+def get_array_wrap(*args):
     """Find the wrapper for the array with the highest priority.
 
     In case of ties, leftmost wins. If no wrapper is found, return None
@@ -547,19 +547,28 @@
      [ ...                                   ...   ],
      [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b  ]]
     """
-    wrapper = _getwrapper(a, b)
+    wrapper = get_array_wrap(a, b)
     b = asanyarray(b)    
     a = array(a,copy=False,subok=True,ndmin=b.ndim)
+    ndb, nda = b.ndim, a.ndim
+    if (nda == 0 or ndb == 0):
+        return a * b
     as = a.shape
     bs = b.shape
     if not a.flags.contiguous:
         a = reshape(a, as)
     if not b.flags.contiguous:
         b = reshape(b, bs)
-    o = outer(a,b)
-    result = o.reshape(as + bs)
-    axis = a.ndim-1
-    for k in xrange(b.ndim):
+    nd = ndb
+    if (ndb != nda):
+        if (ndb > nda):
+            as = (1,)*(ndb-nda) + as
+        else:
+            bs = (1,)*(nda-ndb) + bs
+            nd = nda        
+    result = outer(a,b).reshape(as+bs)
+    axis = nd-1
+    for k in xrange(nd):
         result = concatenate(result, axis=axis)
     if wrapper is not None:
         result = wrapper(result)

Modified: trunk/numpy/lib/tests/test_shape_base.py
===================================================================
--- trunk/numpy/lib/tests/test_shape_base.py	2006-10-08 13:16:13 UTC (rev 3292)
+++ trunk/numpy/lib/tests/test_shape_base.py	2006-10-09 07:47:06 UTC (rev 3293)
@@ -11,8 +11,6 @@
         a = ones((20,10),'d')
         assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
     def check_simple101(self,level=11):
-        # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
-        # when enabled and shape(a)[1]>100. See Issue 202.
         a = ones((10,101),'d')
         assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
 
@@ -370,6 +368,7 @@
         assert_equal(type(kron(a,ma)), ndarray) 
         assert_equal(type(kron(ma,a)), myarray) 
 
+
 class test_tile(NumpyTestCase):
     def check_basic(self):
         a = array([0,1,2])
@@ -380,7 +379,19 @@
         assert_equal(tile(b, 2), [[1,2,1,2],[3,4,3,4]])
         assert_equal(tile(b,(2,1)),[[1,2],[3,4],[1,2],[3,4]])
         assert_equal(tile(b,(2,2)),[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])
-        
+
+    def check_kroncompare(self):
+        import numpy.random as nr
+	reps=[(2,),(1,2),(2,1),(2,2),(2,3,2),(3,2)]
+        shape=[(3,),(2,3),(3,4,3),(3,2,3),(4,3,2,4),(2,2)]
+        for s in shape:
+            b = nr.randint(0,10,size=s)
+            for r in reps:
+                a = ones(r, b.dtype)
+                large = tile(b, r)
+                klarge = kron(a, b)
+                assert_equal(large, klarge)
+
 # Utility
 
 def compare_results(res,desired):



More information about the Numpy-svn mailing list