[Numpy-svn] r6315 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Fri Jan 9 14:18:17 CST 2009


Author: pierregm
Date: 2009-01-09 14:18:12 -0600 (Fri, 09 Jan 2009)
New Revision: 6315

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* Added flatten_structured_array
* Fixed _get_recordmask for nested structures

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2009-01-09 19:37:16 UTC (rev 6314)
+++ trunk/numpy/ma/core.py	2009-01-09 20:18:12 UTC (rev 6315)
@@ -1483,6 +1483,56 @@
         return d
 
 
+def flatten_structured_array(a):
+    """
+    Flatten a structured array.
+
+    The datatype of the output is the largest datatype of the (nested) fields.
+
+    Returns
+    -------
+    output : var
+        Flattened MaskedArray if the input is a MaskedArray,
+        standard ndarray otherwise.
+
+    Examples
+    --------
+    >>> ndtype = [('a', int), ('b', float)]
+    >>> a = np.array([(1, 1), (2, 2)], dtype=ndtype)
+    >>> flatten_structured_array(a)
+    array([[1., 1.],
+           [2., 2.]])
+
+    """
+    #
+    def flatten_sequence(iterable):
+        """Flattens a compound of nested iterables."""
+        for elm in iter(iterable):
+            if hasattr(elm,'__iter__'):
+                for f in flatten_sequence(elm):
+                    yield f
+            else:
+                yield elm
+    #
+    a = np.asanyarray(a)
+    inishape = a.shape
+    a = a.ravel()
+    if isinstance(a, MaskedArray):
+        out = np.array([tuple(flatten_sequence(d.item())) for d in a._data])
+        out = out.view(MaskedArray)
+        out._mask = np.array([tuple(flatten_sequence(d.item()))
+                              for d in getmaskarray(a)])
+    else:
+        out = np.array([tuple(flatten_sequence(d.item())) for d in a])
+    if len(inishape) > 1:
+        newshape = list(out.shape)
+        newshape[0] = inishape
+        out.shape = tuple(flatten_sequence(newshape))
+    return out
+
+
+
+
 class MaskedArray(ndarray):
     """
     Arrays with possibly masked values.  Masked values of True
@@ -2021,34 +2071,28 @@
 #        return self._mask.reshape(self.shape)
         return self._mask
     mask = property(fget=_get_mask, fset=__setmask__, doc="Mask")
-    #
-    def _getrecordmask(self):
-        """Return the mask of the records.
+
+
+    def _get_recordmask(self):
+        """
+    Return the mask of the records.
     A record is masked when all the fields are masked.
 
         """
         _mask = ndarray.__getattribute__(self, '_mask').view(ndarray)
         if _mask.dtype.names is None:
             return _mask
-        if _mask.size > 1:
-            axis = 1
-        else:
-            axis = None
-        #
-        try:
-            return _mask.view((bool_, len(self.dtype))).all(axis)
-        except ValueError:
-            # In case we have nested fields...
-            return np.all([[f[n].all() for n in _mask.dtype.names]
-                           for f in _mask], axis=axis)
+        return np.all(flatten_structured_array(_mask), axis=-1)
 
-    def _setrecordmask(self):
+
+    def _set_recordmask(self):
         """Return the mask of the records.
     A record is masked when all the fields are masked.
 
         """
         raise NotImplementedError("Coming soon: setting the mask per records!")
-    recordmask = property(fget=_getrecordmask)
+    recordmask = property(fget=_get_recordmask)
+
     #............................................
     def harden_mask(self):
         """Force the mask to hard.

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2009-01-09 19:37:16 UTC (rev 6314)
+++ trunk/numpy/ma/tests/test_core.py	2009-01-09 20:18:12 UTC (rev 6315)
@@ -482,8 +482,12 @@
         test = a.filled(0)
         control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype)
         assert_equal(test, control)
-        
+        #
+        test = a['B'].filled(0)
+        control = np.array([(0, 1), (2, 0)], dtype=a['B'].dtype)
+        assert_equal(test, control)
 
+
     def test_optinfo_propagation(self):
         "Checks that _optinfo dictionary isn't back-propagated"
         x = array([1,2,3,], dtype=float)
@@ -503,6 +507,45 @@
         control = "[(--, (2, --)) (4, (--, 6.0))]"
         assert_equal(str(test), control)
 
+
+    def test_flatten_structured_array(self):
+        "Test flatten_structured_array on arrays"
+        # On ndarray
+        ndtype = [('a', int), ('b', float)]
+        a = np.array([(1, 1), (2, 2)], dtype=ndtype)
+        test = flatten_structured_array(a)
+        control = np.array([[1., 1.], [2., 2.]], dtype=np.float)
+        assert_equal(test, control)
+        assert_equal(test.dtype, control.dtype)
+        # On masked_array
+        a = ma.array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+        test = flatten_structured_array(a)
+        control = ma.array([[1., 1.], [2., 2.]], 
+                           mask=[[0, 1], [1, 0]], dtype=np.float)
+        assert_equal(test, control)
+        assert_equal(test.dtype, control.dtype)
+        assert_equal(test.mask, control.mask)
+        # On masked array with nested structure
+        ndtype = [('a', int), ('b', [('ba', int), ('bb', float)])]
+        a = ma.array([(1, (1, 1.1)), (2, (2, 2.2))],
+                     mask=[(0, (1, 0)), (1, (0, 1))], dtype=ndtype)
+        test = flatten_structured_array(a)
+        control = ma.array([[1., 1., 1.1], [2., 2., 2.2]], 
+                           mask=[[0, 1, 0], [1, 0, 1]], dtype=np.float)
+        assert_equal(test, control)
+        assert_equal(test.dtype, control.dtype)
+        assert_equal(test.mask, control.mask)
+        # Keeping the initial shape
+        ndtype = [('a', int), ('b', float)]
+        a = np.array([[(1, 1),], [(2, 2),]], dtype=ndtype)
+        test = flatten_structured_array(a)
+        control = np.array([[[1., 1.],], [[2., 2.],]], dtype=np.float)
+        assert_equal(test, control)
+        assert_equal(test.dtype, control.dtype)
+
+
+        
+
 #------------------------------------------------------------------------------
 
 class TestMaskedArrayArithmetic(TestCase):



More information about the Numpy-svn mailing list