[Numpy-svn] r3210 - trunk/numpy/core/tests

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Sep 22 23:45:16 CDT 2006


Author: oliphant
Date: 2006-09-22 23:45:08 -0500 (Fri, 22 Sep 2006)
New Revision: 3210

Modified:
   trunk/numpy/core/tests/test_regression.py
Log:
Add test for default axis in methods and functions.

Modified: trunk/numpy/core/tests/test_regression.py
===================================================================
--- trunk/numpy/core/tests/test_regression.py	2006-09-23 03:47:27 UTC (rev 3209)
+++ trunk/numpy/core/tests/test_regression.py	2006-09-23 04:45:08 UTC (rev 3210)
@@ -411,6 +411,45 @@
         x = N.array((1,2), dtype=dt)
         x = x.byteswap()
         assert(x['one'] > 1 and x['two'] > 2)
+
+    def check_method_args(self, level=rlevel):
+        # Check that ndarray methods and the same-named numpy functions
+        # agree when called with their default axis/keyword arguments.
+        funcs1= ['argmax', 'argmin', 'sum', ('product', 'prod'),  # name, or (function, method) name pair
+                 ('sometrue', 'any'),
+                 ('alltrue', 'all'), 'cumsum', ('cumproduct', 'cumprod'),
+                 'ptp', 'cumprod', 'prod', 'std', 'var', 'mean',
+                 'round', 'min', 'max', 'argsort', 'sort']
+        funcs2 = ['compress', 'take', 'repeat']  # take a second array argument
         
+        for func in funcs1:
+            arr = N.random.rand(8,7)  # fresh 8x7 input per function
+            arr2 = arr.copy()  # separate copy so an in-place method cannot affect the function call
+            if isinstance(func, tuple):
+                func_meth = func[1]
+                func = func[0]
+            else:
+                func_meth = func
+            res1 = getattr(arr, func_meth)()  # method form, no arguments
+            res2 = getattr(N, func)(arr2)  # function form on the copy
+            if res1 is None:  # method returned None (in-place, e.g. sort): compare arr itself
+                assert abs(arr-res2).max() < 1e-8, func
+            else:
+                assert abs(res1-res2).max() < 1e-8, func
+
+        for func in funcs2:
+            arr1 = N.random.rand(8,7)
+            arr2 = N.random.rand(8,7)
+            res1 = None
+            if func == 'compress':  # flattened arr1 is the condition, arr2 the data
+                arr1 = arr1.ravel()
+                res1 = getattr(arr2, func)(arr1)
+            else:
+                arr2 = (15*arr2).astype(int).ravel()  # integer indices/counts in 0..14
+            if res1 is None:
+                res1 = getattr(arr1, func)(arr2)
+            res2 = getattr(N, func)(arr1, arr2)
+            assert abs(res1-res2).max() < 1e-8, func
+        
 if __name__ == "__main__":
     NumpyTest().run()



More information about the Numpy-svn mailing list