[Numpy-svn] r3006 - trunk/numpy/lib

numpy-svn at scipy.org numpy-svn at scipy.org
Sun Aug 13 05:03:15 CDT 2006


Author: oliphant
Date: 2006-08-13 05:03:13 -0500 (Sun, 13 Aug 2006)
New Revision: 3006

Modified:
   trunk/numpy/lib/utils.py
Log:
Improve ndpointer to allow ndim, shape, and flags checking as well.

Modified: trunk/numpy/lib/utils.py
===================================================================
--- trunk/numpy/lib/utils.py	2006-08-13 09:04:39 UTC (rev 3005)
+++ trunk/numpy/lib/utils.py	2006-08-13 10:03:13 UTC (rev 3006)
@@ -1,8 +1,8 @@
 import sys, os
 import inspect
 import types
-from numpy.core.numerictypes import obj2sctype
-from numpy.core.multiarray import dtype
+from numpy.core.numerictypes import obj2sctype, integer
+from numpy.core.multiarray import dtype, _flagdict
 from numpy.core import product, ndarray
 
 __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
@@ -76,30 +76,96 @@
     libpath = os.path.join(libdir, libname)
     return ctypes.cdll[libpath]
 
+def _num_fromflags(flaglist):
+    """Return the bitmask (as compared with ndarray.flags.num) for flag names."""
+    num = 0
+    for val in flaglist:
+        # OR, not add: a repeated flag name must not carry into other bits.
+        num |= _flagdict[val]
+    return num
+
+def _flags_fromnum(num):
+    """Return the list of flag names whose bits are set in num."""
+    res = []
+    for key, value in _flagdict.items():
+        if (num & value):
+            res.append(key)
+    return res
+
 class _ndptr(object):
     def from_param(cls, obj):
         if not isinstance(obj, ndarray):
             raise TypeError("argument must be an ndarray")
         if obj.dtype != cls._dtype_:
             raise TypeError("array must have data type", cls._dtype_)
+        # Compare against None explicitly so ndim=0 and shape=() are
+        # still enforced rather than silently skipped.
+        if cls._ndim_ is not None and obj.ndim != cls._ndim_:
+            raise TypeError("array must have %d dimension(s)" % cls._ndim_)
+        if cls._shape_ is not None and obj.shape != cls._shape_:
+            raise TypeError("array must have shape %s" % (cls._shape_,))
+        if (cls._flags_ is not None and
+                (obj.flags.num & cls._flags_) != cls._flags_):
+            raise TypeError("array must have flags %s" %
+                            _flags_fromnum(cls._flags_))
         return obj.ctypes
     from_param = classmethod(from_param)
 
-# Factory for a type-checking object with from_param defined
+# Factory for an array-checking object with from_param defined
 _pointer_type_cache = {}
-def ndpointer(datatype):
+def ndpointer(datatype, ndim=None, shape=None, flags=None):
+    """Return a class whose from_param classmethod validates ndarrays.
+
+    datatype -- required data type (anything accepted by dtype()).
+    ndim     -- required number of dimensions, or None for no check.
+    shape    -- required shape (a sequence, or an int for 1-d), or None.
+    flags    -- required flags: a comma-separated string, a sequence of
+                flag names, or an integer bitmask; None for no check.
+
+    from_param raises TypeError when an argument fails any requirement
+    and otherwise returns obj.ctypes.  Classes are cached, so identical
+    requirements return the same class.
+    """
     datatype = dtype(datatype)
+    num = None
+    if flags is not None:
+        if isinstance(flags, str):
+            flags = flags.split(',')
+        elif isinstance(flags, (int, integer)):
+            num = flags
+            flags = _flags_fromnum(flags)
+        if num is None:
+            flags = [x.strip().upper() for x in flags]
+            num = _num_fromflags(flags)
+    # Normalize shape to a tuple *before* the cache lookup so the key is
+    # hashable and list/tuple/int spellings of one shape share an entry.
+    if shape is not None:
+        try:
+            shape = tuple(shape)
+        except TypeError:
+            shape = (shape,)
+    cache_key = (datatype, ndim, shape, num)
     try:
-        return _pointer_type_cache[datatype]
+        return _pointer_type_cache[cache_key]
     except KeyError:
         pass
     if datatype.names:
         name = str(id(datatype))
     else:
         name = datatype.str
+    if ndim is not None:
+        name += "_%dd" % ndim
+    if shape is not None:
+        name += "_"+"x".join([str(x) for x in shape])
+    if flags is not None:
+        name += "_"+"_".join(flags)
     klass = type("ndpointer_%s"%name, (_ndptr,),
-                 {"_dtype_": datatype})
-    _pointer_type_cache[datatype] = klass
+                 {"_dtype_": datatype,
+                  "_shape_" : shape,
+                  "_ndim_" : ndim,
+                  "_flags_" : num})
+    # Store under the same key used for the lookup above.
+    _pointer_type_cache[cache_key] = klass
     return klass
 
 



More information about the Numpy-svn mailing list