[Numpy-svn] r6305 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Thu Jan 8 14:02:31 CST 2009


Author: pierregm
Date: 2009-01-08 14:02:29 -0600 (Thu, 08 Jan 2009)
New Revision: 6305

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* Add __eq__ and __ne__ to support comparisons of flexible (structured) arrays.
* Fix .filled for nested structures.

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2009-01-08 19:22:21 UTC (rev 6304)
+++ trunk/numpy/ma/core.py	2009-01-08 20:02:29 UTC (rev 6305)
@@ -857,6 +857,7 @@
 #####--------------------------------------------------------------------------
 #---- --- Mask creation functions ---
 #####--------------------------------------------------------------------------
+
 def _recursive_make_descr(datatype, newtype=bool_):
     "Private function allowing recursion in make_descr."
     # Do we have some name fields ?
@@ -1134,6 +1135,7 @@
     result._mask = cond
     return result
 
+
 def masked_greater(x, value, copy=True):
     """
     Return the array `x` masked where (x > value).
@@ -1142,22 +1144,27 @@
     """
     return masked_where(greater(x, value), x, copy=copy)
 
+
 def masked_greater_equal(x, value, copy=True):
     "Shortcut to masked_where, with condition = (x >= value)."
     return masked_where(greater_equal(x, value), x, copy=copy)
 
+
 def masked_less(x, value, copy=True):
     "Shortcut to masked_where, with condition = (x < value)."
     return masked_where(less(x, value), x, copy=copy)
 
+
 def masked_less_equal(x, value, copy=True):
     "Shortcut to masked_where, with condition = (x <= value)."
     return masked_where(less_equal(x, value), x, copy=copy)
 
+
 def masked_not_equal(x, value, copy=True):
     "Shortcut to masked_where, with condition = (x != value)."
     return masked_where(not_equal(x, value), x, copy=copy)
 
+
 def masked_equal(x, value, copy=True):
     """
     Shortcut to masked_where, with condition = (x == value).  For
@@ -1171,6 +1178,7 @@
     # return array(d, mask=m, copy=copy)
     return masked_where(equal(x, value), x, copy=copy)
 
+
 def masked_inside(x, v1, v2, copy=True):
     """
     Shortcut to masked_where, where ``condition`` is True for x inside
@@ -1188,6 +1196,7 @@
     condition = (xf >= v1) & (xf <= v2)
     return masked_where(condition, x, copy=copy)
 
+
 def masked_outside(x, v1, v2, copy=True):
     """
     Shortcut to ``masked_where``, where ``condition`` is True for x outside
@@ -1205,7 +1214,7 @@
     condition = (xf < v1) | (xf > v2)
     return masked_where(condition, x, copy=copy)
 
-#
+
 def masked_object(x, value, copy=True, shrink=True):
     """
     Mask the array `x` where the data are exactly equal to value.
@@ -1234,6 +1243,7 @@
     mask = mask_or(mask, make_mask(condition, shrink=shrink))
     return masked_array(x, mask=mask, copy=copy, fill_value=value)
 
+
 def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
     """
     Mask the array x where the data are approximately equal in
@@ -1271,6 +1281,7 @@
     mask = mask_or(mask, make_mask(condition, shrink=shrink))
     return masked_array(xnew, mask=mask, copy=copy, fill_value=value)
 
+
 def masked_invalid(a, copy=True):
     """
     Mask the array for invalid values (NaNs or infs).
@@ -1292,6 +1303,7 @@
 #####--------------------------------------------------------------------------
 #---- --- Printing options ---
 #####--------------------------------------------------------------------------
+
 class _MaskedPrintOption:
     """
     Handle the string used to represent missing data in a masked array.
@@ -1372,6 +1384,20 @@
 #---- --- MaskedArray class ---
 #####--------------------------------------------------------------------------
 
+def _recursive_filled(a, mask, fill_value):
+    """
+    Recursively fill `a` in place with `fill_value` where `mask` is True.
+    Private function. `mask` and `fill_value` are indexed with the field
+    names of `a` (nested fields recursively), so they must share its layout.
+    """
+    names = a.dtype.names
+    for name in names:
+        current = a[name]
+        if current.dtype.names:
+            _recursive_filled(current, mask[name], fill_value[name])
+        else:
+            np.putmask(current, mask[name], fill_value[name])
+
 #...............................................................................
 class _arraymethod(object):
     """
@@ -2013,6 +2039,7 @@
         try:
             return _mask.view((bool_, len(self.dtype))).all(axis)
         except ValueError:
+            # In case we have nested fields...
             return np.all([[f[n].all() for n in _mask.dtype.names]
                            for f in _mask], axis=axis)
 
@@ -2106,6 +2133,7 @@
     fill_value = property(fget=get_fill_value, fset=set_fill_value,
                           doc="Filling value.")
 
+
     def filled(self, fill_value=None):
         """Return a copy of self._data, where masked values are filled
         with fill_value.
@@ -2140,9 +2168,10 @@
         #
         if m.dtype.names:
             result = self._data.copy()
-            for n in result.dtype.names:
-                field = result[n]
-                np.putmask(field, self._mask[n], fill_value[n])
+            # Fill each field from the matching field of fill_value,
+            # recursing into nested structures (a flat per-field
+            # np.putmask over the top-level names cannot fill nested fields).
+            _recursive_filled(result, self._mask, fill_value)
         elif not m.any():
             return self._data
         else:
@@ -2287,6 +2316,66 @@
             return _print_templates['short'] % parameters
         return _print_templates['long'] % parameters
     #............................................
+    def __eq__(self, other):
+        "Check whether other equals self elementwise"
+        return self._comparison(other, ndarray.__eq__)
+    #
+    def __ne__(self, other):
+        "Check whether other doesn't equal self elementwise"
+        return self._comparison(other, ndarray.__ne__)
+    #
+    def _comparison(self, other, compare):
+        """
+        Compare self and other elementwise with `compare`.
+        Private method shared by __eq__ and __ne__.
+
+        Parameters
+        ----------
+        other : object
+            Object compared to self. If it has a `_mask`, that mask is
+            combined with the mask of self.
+        compare : function
+            Unbound comparison (``ndarray.__eq__`` or ``ndarray.__ne__``)
+            applied to the data filled with 0.
+
+        Returns
+        -------
+        check : MaskedArray
+            Result of the comparison. With a structured mask, an element
+            of the result is masked only if all its fields are masked.
+        """
+        def _all_masked(m):
+            "Return True where all the (nested) fields of `m` are True."
+            if not m.dtype.names:
+                return m
+            result = np.ones(m.shape, dtype=bool_)
+            for name in m.dtype.names:
+                field = _all_masked(m[name])
+                # Sub-array fields carry extra trailing dimensions.
+                if field.ndim > result.ndim:
+                    field = field.reshape(result.shape + (-1,)).all(-1)
+                result &= field
+            return result
+        #
+        omask = getattr(other, '_mask', nomask)
+        if omask is nomask:
+            check = compare(self.filled(0), other).view(type(self))
+            mask = self._mask
+        else:
+            odata = filled(other, 0)
+            check = compare(self.filled(0), odata).view(type(self))
+            if self._mask is nomask:
+                mask = omask
+            else:
+                mask = mask_or(self._mask, omask)
+        # The comparison yields one boolean per record, whereas a structured
+        # mask holds one boolean per field: collapse it to match `check`,
+        # whichever of self/other carries the mask, for any array shape.
+        if mask.dtype.names:
+            mask = _all_masked(mask)
+        check._mask = mask
+        return check
+    #
     def __add__(self, other):
         "Add other to self, and return a new masked array."
         return add(self, other)

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2009-01-08 19:22:21 UTC (rev 6304)
+++ trunk/numpy/ma/tests/test_core.py	2009-01-08 20:02:29 UTC (rev 6305)
@@ -474,6 +474,16 @@
                      np.array([(1, '1', 1.)], dtype=flexi.dtype))
 
 
+    def test_filled_w_nested_dtype(self):
+        "Test filled with a nested structured dtype"
+        ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
+        a = array([(1, (1, 1)), (2, (2, 2))],
+                  mask=[(0, (1, 0)), (0, (0, 1))], dtype=ndtype)
+        test = a.filled(0)
+        control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype)
+        assert_equal(test, control)
+
+
     def test_optinfo_propagation(self):
         "Checks that _optinfo dictionary isn't back-propagated"
         x = array([1,2,3,], dtype=float)
@@ -884,6 +894,40 @@
             self.failUnless(output[0] is masked)
 
 
+    def test_eq_on_structured(self):
+        "Test the equality of structured arrays"
+        ndtype = [('A', int), ('B', int)]
+        a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+        test = (a == a)
+        assert_equal(test, [True, True])
+        assert_equal(test.mask, [False, False])
+        b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
+        test = (a == b)
+        assert_equal(test, [False, True])
+        assert_equal(test.mask, [True, False])
+        b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+        test = (a == b)
+        assert_equal(test, [True, False])
+        assert_equal(test.mask, [False, False])
+
+
+    def test_ne_on_structured(self):
+        "Test the inequality of structured arrays"
+        ndtype = [('A', int), ('B', int)]
+        a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+        test = (a != a)
+        assert_equal(test, [False, False])
+        assert_equal(test.mask, [False, False])
+        b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
+        test = (a != b)
+        assert_equal(test, [True, False])
+        assert_equal(test.mask, [True, False])
+        b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+        test = (a != b)
+        assert_equal(test, [False, True])
+        assert_equal(test.mask, [False, False])
+
+
 #------------------------------------------------------------------------------
 
 class TestMaskedArrayAttributes(TestCase):



More information about the Numpy-svn mailing list