[Numpy-svn] r5114 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Wed Apr 30 14:36:45 CDT 2008


Author: pierregm
Date: 2008-04-30 14:36:42 -0500 (Wed, 30 Apr 2008)
New Revision: 5114

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
   trunk/numpy/ma/testutils.py
Log:
core      : fixed a bug where array((0,0))/0. lost its shape.
testutils : introduced assert_almost_equal/assert_approx_equal:
            use assert_almost_equal(a,b,decimal) to compare a and b to `decimal` decimal places;
            use assert_approx_equal(a,b,decimal) to compare a and b to within b*10.**-decimal.

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-04-30 16:01:25 UTC (rev 5113)
+++ trunk/numpy/ma/core.py	2008-04-30 19:36:42 UTC (rev 5114)
@@ -613,9 +613,8 @@
         t = narray(self.domain(d1, d2), copy=False)
         if t.any(None):
             mb = mask_or(mb, t)
-            # The following two lines control the domain filling
-            d2 = d2.copy()
-            numpy.putmask(d2, t, self.filly)
+            # The following line controls the domain filling
+            d2 = numpy.where(t,self.filly,d2)
         m = mask_or(ma, mb)
         if (not m.ndim) and m:
             return masked

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-04-30 16:01:25 UTC (rev 5113)
+++ trunk/numpy/ma/tests/test_core.py	2008-04-30 19:36:42 UTC (rev 5114)
@@ -257,6 +257,10 @@
         x = array(0, mask=0)
         assert_equal(x.filled().ctypes.data, x.ctypes.data)
         assert_equal(str(xm), str(masked_print_option))
+        # Make sure we don't lose the shape in some circumstances
+        xm = array((0,0))/0.
+        assert_equal(xm.shape,(2,))
+        assert_equal(xm.mask,[1,1])        
     #.........................
     def test_basic_ufuncs (self):
         "Test various functions such as sin, cos."

Modified: trunk/numpy/ma/testutils.py
===================================================================
--- trunk/numpy/ma/testutils.py	2008-04-30 16:01:25 UTC (rev 5113)
+++ trunk/numpy/ma/testutils.py	2008-04-30 19:36:42 UTC (rev 5114)
@@ -40,6 +40,23 @@
     y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
     d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
     return d.ravel()
+
+def almost(a, b, decimal=6, fill_value=True):
+    """Returns True if a and b are equal up to decimal places.
+If fill_value is True, masked values considered equal. Otherwise, masked values
+are considered unequal.
+    """
+    m = mask_or(getmask(a), getmask(b))
+    d1 = filled(a)
+    d2 = filled(b)
+    if d1.dtype.char == "O" or d2.dtype.char == "O":
+        return N.equal(d1,d2).ravel()
+    x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
+    y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
+    d = N.around(N.abs(x-y),decimal) <= 10.0**(-decimal)
+    return d.ravel()
+    
+
 #................................................
 def _assert_equal_on_sequences(actual, desired, err_msg=''):
     "Asserts the equality of two non-array sequences."
@@ -191,7 +208,7 @@
     assert_array_compare(compare, x, y, err_msg=err_msg,
                          header='Arrays are not equal')
 #............................
-def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+def assert_array_approx_equal(x, y, decimal=6, err_msg=''):
     """Checks the elementwise equality of two masked arrays, up to a given
     number of decimals."""
     def compare(x, y):
@@ -200,6 +217,15 @@
     assert_array_compare(compare, x, y, err_msg=err_msg,
                          header='Arrays are not almost equal')
 #............................
+def assert_array_almost_equal(x, y, decimal=6, err_msg=''):
+    """Checks the elementwise equality of two masked arrays, up to a given
+    number of decimals."""
+    def compare(x, y):
+        "Returns the result of the loose comparison between x and y)."
+        return almost(x,y,decimal)
+    assert_array_compare(compare, x, y, err_msg=err_msg,
+                         header='Arrays are not almost equal')
+#............................
 def assert_array_less(x, y, err_msg=''):
     "Checks that x is smaller than y elementwise."
     assert_array_compare(less, x, y, err_msg=err_msg,



More information about the Numpy-svn mailing list