[Numpy-svn] r3017 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 14 15:13:39 CDT 2006


Author: oliphant
Date: 2006-08-14 15:13:33 -0500 (Mon, 14 Aug 2006)
New Revision: 3017

Added:
   trunk/numpy/lib/tests/test_utils.py
Modified:
   trunk/numpy/lib/utils.py
Log:
Fix ndpointer and add tests from ticket #245

Added: trunk/numpy/lib/tests/test_utils.py
===================================================================
--- trunk/numpy/lib/tests/test_utils.py	2006-08-14 20:01:12 UTC (rev 3016)
+++ trunk/numpy/lib/tests/test_utils.py	2006-08-14 20:13:33 UTC (rev 3017)
@@ -0,0 +1,62 @@
+from numpy.testing import * 
+set_package_path() 
+import numpy as N 
+restore_path() 
+ 
+class test_ndpointer(NumpyTestCase):  # from_param validation checks for N.ndpointer (ticket #245)
+    def check_dtype(self):  # scalar, string, byte-order and record dtypes
+        dt = N.intc 
+        p = N.ndpointer(dtype=dt) 
+        self.assert_(p.from_param(N.array([1], dt))) 
+        dt = '<i4'  # dtype given as a string
+        p = N.ndpointer(dtype=dt) 
+        self.assert_(p.from_param(N.array([1], dt))) 
+        dt = N.dtype('>i4')  # explicit big-endian dtype object
+        p = N.ndpointer(dtype=dt) 
+        p.from_param(N.array([1], dt))  # matching byte order: must not raise
+        self.assertRaises(TypeError, p.from_param,  # byte-swapped array is rejected
+                          N.array([1], dt.newbyteorder('swap'))) 
+        dtnames = ['x', 'y'] 
+        dtformats = [N.intc, N.float64] 
+        dtdescr = {'names' : dtnames, 'formats' : dtformats} 
+        dt = N.dtype(dtdescr)  # record dtype built from a names/formats dict
+        p = N.ndpointer(dtype=dt) 
+        self.assert_(p.from_param(N.zeros((10,), dt))) 
+        samedt = N.dtype(dtdescr)  # equal but separately constructed dtype
+        p = N.ndpointer(dtype=samedt) 
+        self.assert_(p.from_param(N.zeros((10,), dt))) 
+        dt2 = N.dtype(dtdescr, align=True)  # expectation below depends on whether align changes itemsize
+        if dt.itemsize != dt2.itemsize:
+            self.assertRaises(TypeError, p.from_param, N.zeros((10,), dt2))  # different layout: rejected
+        else:
+            self.assert_(p.from_param(N.zeros((10,), dt2)))  # same itemsize: accepted
+ 
+    def check_ndim(self):  # ndim=0, 1 and 2 requirements
+        p = N.ndpointer(ndim=0) 
+        self.assert_(p.from_param(N.array(1))) 
+        self.assertRaises(TypeError, p.from_param, N.array([1])) 
+        p = N.ndpointer(ndim=1) 
+        self.assertRaises(TypeError, p.from_param, N.array(1)) 
+        self.assert_(p.from_param(N.array([1]))) 
+        p = N.ndpointer(ndim=2) 
+        self.assert_(p.from_param(N.array([[1]]))) 
+         
+    def check_shape(self):  # fixed shape (1,2) and the empty shape ()
+        p = N.ndpointer(shape=(1,2)) 
+        self.assert_(p.from_param(N.array([[1,2]]))) 
+        self.assertRaises(TypeError, p.from_param, N.array([[1],[2]])) 
+        p = N.ndpointer(shape=()) 
+        self.assert_(p.from_param(N.array(1))) 
+ 
+    def check_flags(self):  # flags given by name and by integer value
+        x = N.array([[1,2,3]], order='F')  # NOTE(review): a 1x3 array may be both C- and F-contiguous under some NumPy stride rules, which would stop the rejections below from raising; a 2x2 array is unambiguous -- confirm
+        p = N.ndpointer(flags='FORTRAN') 
+        self.assert_(p.from_param(x)) 
+        p = N.ndpointer(flags='CONTIGUOUS') 
+        self.assertRaises(TypeError, p.from_param, x) 
+        p = N.ndpointer(flags=x.flags.num)  # integer flags value copied from x
+        self.assert_(p.from_param(x)) 
+        self.assertRaises(TypeError, p.from_param, N.array([[1,2,3]]))  # C-ordered array expected to lack x's flags
+ 
+if __name__ == "__main__":  # run this module's tests when executed directly
+    NumpyTest().run() 

Modified: trunk/numpy/lib/utils.py
===================================================================
--- trunk/numpy/lib/utils.py	2006-08-14 20:01:12 UTC (rev 3016)
+++ trunk/numpy/lib/utils.py	2006-08-14 20:13:33 UTC (rev 3017)
@@ -2,7 +2,7 @@
 import inspect
 import types
 from numpy.core.numerictypes import obj2sctype, integer
-from numpy.core.multiarray import dtype as _dtype, _flagdict
+from numpy.core.multiarray import dtype as _dtype, _flagdict, flagsobj
 from numpy.core import product, ndarray
 
 __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
@@ -101,7 +101,7 @@
             raise TypeError, "array must have %d dimension(s)" % cls._ndim_
         if cls._shape_ is not None \
                and obj.shape != cls._shape_:
-            raise TypeError, "array must have shape %s" % cls._shape_
+            raise TypeError, "array must have shape %s" % str(cls._shape_)
         if cls._flags_ is not None \
                and ((obj.flags.num & cls._flags_) != cls._flags_):
             raise TypeError, "array must have flags %s" % \
@@ -121,9 +121,15 @@
             flags = flags.split(',')
         elif isinstance(flags, (int, integer)):
             num = flags
-            flags = _flags_fromnum(flags)
+            flags = _flags_fromnum(num)
+        elif isinstance(flags, flagsobj):
+            num = flags.num
+            flags = _flags_fromnum(num)
         if num is None:
-            flags = [x.strip().upper() for x in flags]
+            try:
+                flags = [x.strip().upper() for x in flags]
+            except:
+                raise TypeError, "invalid flags specification"
             num = _num_fromflags(flags)
     try:
         return _pointer_type_cache[(dtype, ndim, shape, num)]



More information about the Numpy-svn mailing list