[Numpy-svn] r8467 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jun 29 12:57:15 CDT 2010


Author: pierregm
Date: 2010-06-29 12:57:15 -0500 (Tue, 29 Jun 2010)
New Revision: 8467

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
Fixed __eq__/__ne__ for scalars

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2010-06-27 11:09:58 UTC (rev 8466)
+++ trunk/numpy/ma/core.py	2010-06-29 17:57:15 UTC (rev 8467)
@@ -3584,8 +3584,13 @@
             return masked
         omask = getattr(other, '_mask', nomask)
         if omask is nomask:
-            check = ndarray.__eq__(self.filled(0), other).view(type(self))
-            check._mask = self._mask
+            check = ndarray.__eq__(self.filled(0), other)
+            try:
+                check = check.view(type(self))
+                check._mask = self._mask
+            except AttributeError:
+                # Dang, we have a bool instead of an array: return the bool
+                return check
         else:
             odata = filled(other, 0)
             check = ndarray.__eq__(self.filled(0), odata).view(type(self))
@@ -3612,8 +3617,13 @@
             return masked
         omask = getattr(other, '_mask', nomask)
         if omask is nomask:
-            check = ndarray.__ne__(self.filled(0), other).view(type(self))
-            check._mask = self._mask
+            check = ndarray.__ne__(self.filled(0), other)
+            try:
+                check = check.view(type(self))
+                check._mask = self._mask
+            except AttributeError:
+                # In case check is a boolean (or a numpy.bool)
+                return check
         else:
             odata = filled(other, 0)
             check = ndarray.__ne__(self.filled(0), odata).view(type(self))

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2010-06-27 11:09:58 UTC (rev 8466)
+++ trunk/numpy/ma/tests/test_core.py	2010-06-29 17:57:15 UTC (rev 8467)
@@ -1149,6 +1149,21 @@
         assert_equal(test.mask, [False, False])
 
 
+    def test_eq_w_None(self):
+        a = array([1, 2], mask=False)
+        assert_equal(a == None, False)
+        assert_equal(a != None, True)
+        a = masked
+        assert_equal(a == None, masked)
+
+    def test_eq_w_scalar(self):
+        a = array(1)
+        assert_equal(a == 1, True)
+        assert_equal(a == 0, False)
+        assert_equal(a != 1, False)
+        assert_equal(a != 0, True)
+
+
     def test_numpyarithmetics(self):
         "Check that the mask is not back-propagated when using numpy functions"
         a = masked_array([-1, 0, 1, 2, 3], mask=[0, 0, 0, 0, 1])



More information about the Numpy-svn mailing list