[Numpy-svn] r5632 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Aug 12 16:12:17 CDT 2008


Author: pierregm
Date: 2008-08-12 16:12:14 -0500 (Tue, 12 Aug 2008)
New Revision: 5632

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* masked_where : enforce a consistency check between the shapes of the condition and the input

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-08-12 18:46:31 UTC (rev 5631)
+++ trunk/numpy/ma/core.py	2008-08-12 21:12:14 UTC (rev 5632)
@@ -884,6 +884,11 @@
     """
     cond = make_mask(condition)
     a = np.array(a, copy=copy, subok=True)
+    
+    (cshape, ashape) = (cond.shape, a.shape)
+    if cshape and cshape != ashape:
+        raise IndexError("Inconsistant shape between the condition and the input"\
+                         " (got %s and %s)" % (cshape, ashape))
     if hasattr(a, '_mask'):
         cond = mask_or(cond, a._mask)
         cls = type(a)

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-08-12 18:46:31 UTC (rev 5631)
+++ trunk/numpy/ma/tests/test_core.py	2008-08-12 21:12:14 UTC (rev 5632)
@@ -1969,7 +1969,18 @@
         ctest = masked_where(btest,atest)
         assert_equal(atest,ctest)
 
+    def test_masked_where_shape_constraint(self):
+        a = arange(10)
+        try:  # condition/input shape mismatch: masked_where must raise IndexError
+            test = masked_equal(1, a)
+        except IndexError:
+            pass
+        else:
+            raise AssertionError("Should have failed...")
+        test = masked_equal(a,1)
+        assert_equal(test.mask, [0,1,0,0,0,0,0,0,0,0])  # not assert(x, y): a tuple is always true
 
+
     def test_masked_otherfunctions(self):
         assert_equal(masked_inside(range(5), 1, 3), [0, 199, 199, 199, 4])
         assert_equal(masked_outside(range(5), 1, 3),[199,1,2,3,199])



More information about the Numpy-svn mailing list