[Numpy-svn] r4750 - in branches/maskedarray/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Thu Jan 24 03:53:41 CST 2008


Author: pierregm
Date: 2008-01-24 03:53:36 -0600 (Thu, 24 Jan 2008)
New Revision: 4750

Modified:
   branches/maskedarray/numpy/ma/core.py
   branches/maskedarray/numpy/ma/tests/test_core.py
Log:
core: fixed compress so that a.compress(cond) == a[cond] in most cases (the condition is now viewed as a plain ndarray instead of having its masked entries filled with False)

Modified: branches/maskedarray/numpy/ma/core.py
===================================================================
--- branches/maskedarray/numpy/ma/core.py	2008-01-24 01:43:05 UTC (rev 4749)
+++ branches/maskedarray/numpy/ma/core.py	2008-01-24 09:53:36 UTC (rev 4750)
@@ -3121,8 +3121,7 @@
         _view = type(a)
     else:
         _view = MaskedArray
-    # Make sure the condition has no missing values
-    condition = filled(condition, False)
+    condition = condition.view(ndarray)
     #
     _new = ndarray.compress(_data, condition, axis=axis, out=out).view(_view)
     _new._update_from(a)

Modified: branches/maskedarray/numpy/ma/tests/test_core.py
===================================================================
--- branches/maskedarray/numpy/ma/tests/test_core.py	2008-01-24 01:43:05 UTC (rev 4749)
+++ branches/maskedarray/numpy/ma/tests/test_core.py	2008-01-24 09:53:36 UTC (rev 4750)
@@ -1390,20 +1390,28 @@
 
     def test_compress(self):
         "test compress"
-        a = masked_array([10, 20, 30, 40], fill_value=9999)
-        condition = (a > 15) & (a < 35)
-        assert_equal(a.compress(condition),[20,30])
+        a = masked_array([1., 2., 3., 4., 5.], fill_value=9999)
+        condition = (a > 1.5) & (a < 3.5)
+        assert_equal(a.compress(condition),[2.,3.])
         #
-        a[1] = masked
+        a[[2,3]] = masked
         b = a.compress(condition)
-        assert_equal(b._data,[20,30])
-        assert_equal(b._mask,[1,0])
+        assert_equal(b._data,[2.,3.])
+        assert_equal(b._mask,[0,1])
         assert_equal(b.fill_value,9999)
+        assert_equal(b,a[condition])
         #
+        condition = (a<4.)
+        b = a.compress(condition)
+        assert_equal(b._data,[1.,2.,3.])
+        assert_equal(b._mask,[0,0,1])
+        assert_equal(b.fill_value,9999)
+        assert_equal(b,a[condition])
+        #
         a = masked_array([[10,20,30],[40,50,60]], mask=[[0,0,1],[1,0,0]])
         b = a.compress(a.ravel() >= 22)
-        assert_equal(b._data, [50, 60])
-        assert_equal(b._mask, [0,0])
+        assert_equal(b._data, [30, 40, 50, 60])
+        assert_equal(b._mask, [1,1,0,0])
         #
         x = numpy.array([3,1,2])
         b = a.compress(x >= 2, axis=1)    
@@ -1411,7 +1419,6 @@
         assert_equal(b._mask, [[0,1],[1,0]])
 
 
-
 #..............................................................................
 
 ###############################################################################



More information about the Numpy-svn mailing list