[Numpy-svn] r6127 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Mon Dec 1 03:45:53 CST 2008


Author: pierregm
Date: 2008-12-01 03:45:51 -0600 (Mon, 01 Dec 2008)
New Revision: 6127

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
Fixed make_mask_descr for nested dtypes

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-11-30 15:08:38 UTC (rev 6126)
+++ trunk/numpy/ma/core.py	2008-12-01 09:45:51 UTC (rev 6127)
@@ -798,22 +798,27 @@
     Each field is set to a bool.
 
     """
+    def _make_descr(datatype):
+        "Private function allowing recursion."
+        # Do we have some name fields ?
+        names = datatype.names
+        if names:
+            descr = []
+            for name in names:
+                (ndtype, _) = datatype.fields[name]
+                descr.append((name, _make_descr(ndtype)))
+            return descr
+        # Is this some kind of composite a la (np.float,2)
+        elif datatype.subdtype:
+            mdescr = list(datatype.subdtype)
+            mdescr[0] = np.dtype(bool)
+            return tuple(mdescr)
+        else:
+            return np.bool
     # Make sure we do have a dtype
     if not isinstance(ndtype, np.dtype):
         ndtype = np.dtype(ndtype)
-    # Do we have some name fields ?
-    if ndtype.names:
-        mdescr = [list(_) for _ in ndtype.descr]
-        for m in mdescr:
-            m[1] = '|b1'
-        return np.dtype([tuple(_) for _ in mdescr])
-    # Is this some kind of composite a la (np.float,2)
-    elif ndtype.subdtype:
-        mdescr = list(ndtype.subdtype)
-        mdescr[0] = np.dtype(bool)
-        return np.dtype(tuple(mdescr))
-    else:
-        return MaskType
+    return np.dtype(_make_descr(ndtype))
 
 def get_mask(a):
     """Return the mask of a, if any, or nomask.

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-11-30 15:08:38 UTC (rev 6126)
+++ trunk/numpy/ma/tests/test_core.py	2008-12-01 09:45:51 UTC (rev 6127)
@@ -2339,8 +2339,18 @@
         ntype = np.float
         test = make_mask_descr(ntype)
         assert_equal(test, np.dtype(np.bool))
+        #
+        ntype = [('a', np.float), ('b', [('ba', np.float), ('bb', np.float)])]
+        test = make_mask_descr(ntype)
+        control = np.dtype([('a', 'b1'), ('b', [('ba', 'b1'), ('bb', 'b1')])])
+        assert_equal(test, control)
+        #
+        ntype = [('a', (np.float, 2))]
+        test = make_mask_descr(ntype)
+        assert_equal(test, np.dtype([('a', (np.bool, 2))]))
 
 
+
     def test_make_mask(self):
         "Test make_mask"
         # w/ a list as an input



More information about the Numpy-svn mailing list