# [Numpy-svn] r4984 - trunk/numpy/testing

numpy-svn@scip... numpy-svn@scip...
Mon Apr 7 19:09:14 CDT 2008

```
Author: cdavid
Date: 2008-04-07 19:09:12 -0500 (Mon, 07 Apr 2008)
New Revision: 4984

Modified:
trunk/numpy/testing/utils.py
Log:
Handle nan in assert_array* funcs correctly. All numpy tests pass

Modified: trunk/numpy/testing/utils.py
===================================================================
--- trunk/numpy/testing/utils.py	2008-04-08 00:06:57 UTC (rev 4983)
+++ trunk/numpy/testing/utils.py	2008-04-08 00:09:12 UTC (rev 4984)
@@ -186,9 +186,14 @@

def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
-    from numpy.core import asarray
+    from numpy.core import asarray, isnan, any
+    from numpy import isreal, iscomplex
x = asarray(x)
y = asarray(y)
+
+    def isnumber(x):
+        return x.dtype.char in '?bhilqpBHILQPfdgFDG'
+
try:
cond = (x.shape==() or y.shape==()) or x.shape == y.shape
if not cond:
@@ -199,7 +204,25 @@
names=('x', 'y'))
assert cond, msg
-        val = comparison(x,y)
+
+        if (isnumber(x) and isnumber(y)) and (any(isnan(x)) or any(isnan(y))):
+            # Handling nan: we first check that x and y have the nan at the
+            # same locations, and then we mask the nan and do the comparison as
+            # usual.
+            xnanid = isnan(x)
+            ynanid = isnan(y)
+            try:
+                assert_array_equal(xnanid, ynanid)
+            except AssertionError:
+                msg = build_err_msg([x, y],
+                                    err_msg
+                                    + '\n(x and y nan location mismatch %s, '
+                                    + '%s mismatch)' % (xnanid, ynanid),