[Numpy-svn] r6335 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Mon Jan 26 20:46:31 CST 2009


Author: pierregm
Date: 2009-01-26 20:46:26 -0600 (Mon, 26 Jan 2009)
New Revision: 6335

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* prevent MaskedBinaryOperation and DomainedBinaryOperation from shrinking the mask of the output when at least one of the inputs has a mask full of False

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2009-01-26 21:04:26 UTC (rev 6334)
+++ trunk/numpy/ma/core.py	2009-01-27 02:46:26 UTC (rev 6335)
@@ -616,7 +616,7 @@
 
     def __call__ (self, a, b, *args, **kwargs):
         "Execute the call behavior."
-        m = mask_or(getmask(a), getmask(b))
+        m = mask_or(getmask(a), getmask(b), shrink=False)
         (da, db) = (getdata(a), getdata(b))
         # Easy case: there's no mask...
         if m is nomask:
@@ -627,8 +627,12 @@
         # Transforms to a (subclass of) MaskedArray if we don't have a scalar
         if result.shape:
             result = result.view(get_masked_subclass(a, b))
+            # If we have a mask, make sure it's broadcasted properly
             if m.any():
                 result._mask = mask_or(getmaskarray(a), getmaskarray(b))
+            # If some initial masks were not shrunk, don't shrink the result
+            elif m.shape:
+                result._mask = make_mask_none(result.shape, result.dtype)
             if isinstance(a, MaskedArray):
                 result._update_from(a)
             if isinstance(b, MaskedArray):
@@ -754,18 +758,19 @@
     def __call__(self, a, b, *args, **kwargs):
         "Execute the call behavior."
         ma = getmask(a)
-        mb = getmask(b)
+        mb = getmaskarray(b)
         da = getdata(a)
         db = getdata(b)
         t = narray(self.domain(da, db), copy=False)
         if t.any(None):
-            mb = mask_or(mb, t)
+            mb = mask_or(mb, t, shrink=False)
             # The following line controls the domain filling
             if t.size == db.size:
                 db = np.where(t, self.filly, db)
             else:
                 db = np.where(np.resize(t, db.shape), self.filly, db)
-        m = mask_or(ma, mb)
+        # Shrink m if a.mask was nomask, otherwise don't.
+        m = mask_or(ma, mb, shrink=(getattr(a, '_mask', nomask) is nomask))
         if (not m.ndim) and m:
             return masked
         elif (m is nomask):
@@ -774,7 +779,12 @@
             result = np.where(m, da, self.f(da, db, *args, **kwargs))
         if result.shape:
             result = result.view(get_masked_subclass(a, b))
-            result._mask = m
+            # If we have a mask, make sure it's broadcasted properly
+            if m.any():
+                result._mask = mask_or(getmaskarray(a), mb)
+            # If some initial masks were not shrunk, don't shrink the result
+            elif m.shape:
+                result._mask = make_mask_none(result.shape, result.dtype)
             if isinstance(a, MaskedArray):
                 result._update_from(a)
             if isinstance(b, MaskedArray):

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2009-01-26 21:04:26 UTC (rev 6334)
+++ trunk/numpy/ma/tests/test_core.py	2009-01-27 02:46:26 UTC (rev 6335)
@@ -869,6 +869,60 @@
         assert_equal(test.mask, control.mask)
 
 
+    def test_domained_binops_d2D(self):
+        "Test domained binary operations on 2D data"
+        a = array([[1.], [2.], [3.]], mask=[[False], [True], [True]])
+        b = array([[2., 3.], [4., 5.], [6., 7.]])
+        # a's row mask broadcasts across b's columns; masked cells keep a's data
+        test = a / b
+        control = array([[1./2., 1./3.], [2., 2.], [3., 3.]],
+                        mask=[[0, 0], [1, 1], [1, 1]])
+        assert_equal(test, control)
+        assert_equal(test.data, control.data)
+        assert_equal(test.mask, control.mask)
+        # Swapped operands: same mask, but masked cells now keep b's data
+        test = b / a
+        control = array([[2./1., 3./1.], [4., 5.], [6., 7.]],
+                        mask=[[0, 0], [1, 1], [1, 1]])
+        assert_equal(test, control)
+        assert_equal(test.data, control.data)
+        assert_equal(test.mask, control.mask)
+        # Now only b is masked: its mask must pass through to the result as-is
+        a = array([[1.], [2.], [3.]])
+        b = array([[2., 3.], [4., 5.], [6., 7.]],
+                  mask=[[0, 0], [0, 0], [0, 1]])
+        test = a / b
+        control = array([[1./2, 1./3], [2./4, 2./5], [3./6, 3]],
+                        mask=[[0, 0], [0, 0], [0, 1]])
+        assert_equal(test, control)
+        assert_equal(test.data, control.data)
+        assert_equal(test.mask, control.mask)
+        # Swapped operands: b's mask is kept and the masked cell holds b's 7
+        test = b / a
+        control = array([[2/1., 3/1.], [4/2., 5/2.], [6/3., 7]],
+                        mask=[[0, 0], [0, 0], [0, 1]])
+        assert_equal(test, control)
+        assert_equal(test.data, control.data)
+        assert_equal(test.mask, control.mask)
+
+
+    def test_noshrinking(self):
+        "Check that we don't shrink a mask when not wanted"
+        # Binary operation: an all-False mask must not be shrunk to nomask
+        a = masked_array([1,2,3], mask=[False,False,False], shrink=False)
+        b = a + 1
+        assert_equal(b.mask, [0, 0, 0])
+        # In-place binary operation
+        a += 1
+        assert_equal(a.mask, [0, 0, 0])
+        # Domained binary operation (division)
+        b = a / 1.
+        assert_equal(b.mask, [0, 0, 0])
+        # In-place domained binary operation
+        a /= 1.
+        assert_equal(a.mask, [0, 0, 0])
+        
+
     def test_mod(self):
         "Tests mod"
         (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d



More information about the Numpy-svn mailing list