[Numpy-svn] r5672 - in branches/gen_ufuncs/numpy/core: . tests

numpy-svn@scip... numpy-svn@scip...
Thu Aug 21 14:46:05 CDT 2008


Author: stefan
Date: 2008-08-21 14:45:02 -0500 (Thu, 21 Aug 2008)
New Revision: 5672

Modified:
   branches/gen_ufuncs/numpy/core/setup.py
   branches/gen_ufuncs/numpy/core/tests/test_ufunc.py
Log:
Add tests [patch by Wenjie Fu and Hans-Andreas Engel].


Modified: branches/gen_ufuncs/numpy/core/setup.py
===================================================================
--- branches/gen_ufuncs/numpy/core/setup.py	2008-08-21 18:05:00 UTC (rev 5671)
+++ branches/gen_ufuncs/numpy/core/setup.py	2008-08-21 19:45:02 UTC (rev 5672)
@@ -339,7 +339,15 @@
                          extra_info = blas_info
                          )
 
+    config.add_extension('umath_tests',
+                         sources = [join('src','umath_tests.c.src'),
+                                    ],
+                         depends = [join('blasdot','cblas.h'),] + deps,
+                         include_dirs = ['blasdot'],
+                         extra_info = blas_info
+                         )
 
+
     config.add_data_dir('tests')
     config.add_data_dir('tests/data')
 

Modified: branches/gen_ufuncs/numpy/core/tests/test_ufunc.py
===================================================================
--- branches/gen_ufuncs/numpy/core/tests/test_ufunc.py	2008-08-21 18:05:00 UTC (rev 5671)
+++ branches/gen_ufuncs/numpy/core/tests/test_ufunc.py	2008-08-21 19:45:02 UTC (rev 5672)
@@ -1,5 +1,7 @@
 import numpy as np
 from numpy.testing import *
+from numpy.random import rand
+import numpy.core.umath_tests as umt
 
 class TestUfunc(TestCase):
     def test_reduceat_shifting_sum(self) :
@@ -229,6 +231,68 @@
         """
         pass
 
+    def test_innerwt(self):
+        a = np.arange(6).reshape((2,3))
+        b = np.arange(10,16).reshape((2,3))
+        w = np.arange(20,26).reshape((2,3))
+        assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
+        a = np.arange(100,124).reshape((2,3,4))
+        b = np.arange(200,224).reshape((2,3,4))
+        w = np.arange(300,324).reshape((2,3,4))
+        assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
 
+    def test_matrix_multiply(self):
+        self.compare_matrix_multiply_results(np.long)
+        self.compare_matrix_multiply_results(np.double)
+
+    def compare_matrix_multiply_results(self, tp):
+        d1 = np.array(rand(2,3,4), dtype=tp)
+        d2 = np.array(rand(2,3,4), dtype=tp)
+        msg = "matrix multiply on type %s" % d1.dtype.name
+        
+        def permute_n(n):
+            if n == 1:
+                return ([0],)
+            ret = ()
+            base = permute_n(n-1)
+            for perm in base:
+                for i in xrange(n):
+                    new = perm + [n-1]
+                    new[n-1] = new[i]
+                    new[i] = n-1
+                    ret += (new,)
+            return ret
+        def slice_n(n):
+            if n == 0:
+                return ((),)
+            ret = ()
+            base = slice_n(n-1)
+            for sl in base:
+                ret += (sl+(slice(None),),)
+                ret += (sl+(slice(0,1),),)
+            return ret
+        def broadcastable(s1,s2):
+            return s1 == s2 or s1 == 1 or s2 == 1
+        permute_3 = permute_n(3)
+        slice_3 = slice_n(3) + ((slice(None,None,-1),)*3,)
+
+        ref = True
+        for p1 in permute_3:
+            for p2 in permute_3:
+                for s1 in slice_3:
+                    for s2 in slice_3:
+                        a1 = d1.transpose(p1)[s1]
+                        a2 = d2.transpose(p2)[s2]
+                        ref = ref and a1.base != None and a1.base.base != None
+                        ref = ref and a2.base != None and a2.base.base != None
+                        if broadcastable(a1.shape[-1], a2.shape[-2]) and \
+                           broadcastable(a1.shape[0], a2.shape[0]):
+                            assert_array_almost_equal(umt.matrix_multiply(a1,a2), \
+                                np.sum(a2[...,np.newaxis].swapaxes(-3,-1) * \
+                                       a1[...,np.newaxis,:], axis=-1), \
+                                err_msg = msg+' %s %s' % (str(a1.shape),str(a2.shape)))
+
+        assert_equal(ref, True, err_msg="reference check")
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Numpy-svn mailing list