[Scipy-svn] r6551 - branches/0.8.x/scipy/special/tests

scipy-svn@scip... scipy-svn@scip...
Sun Jun 20 18:49:43 CDT 2010


Author: ptvirtan
Date: 2010-06-20 18:49:43 -0500 (Sun, 20 Jun 2010)
New Revision: 6551

Modified:
   branches/0.8.x/scipy/special/tests/testutils.py
Log:
ENH: special: add assert_func_equal to clean up the way dataset tests are implemented

(cherry picked from commit r6547)

Modified: branches/0.8.x/scipy/special/tests/testutils.py
===================================================================
--- branches/0.8.x/scipy/special/tests/testutils.py	2010-06-20 23:47:33 UTC (rev 6550)
+++ branches/0.8.x/scipy/special/tests/testutils.py	2010-06-20 23:49:43 UTC (rev 6551)
@@ -6,6 +6,9 @@
 
 import scipy.special as sc
 
+__all__ = ['with_special_errors', 'assert_tol_equal', 'assert_func_equal',
+           'FuncData']
+
 #------------------------------------------------------------------------------
 # Enable convergence and loss of precision warnings -- turn off one by one
 #------------------------------------------------------------------------------
@@ -46,6 +49,39 @@
 # error reports
 #------------------------------------------------------------------------------
 
+def assert_func_equal(func, results, points, rtol=None, atol=None,
+                      param_filter=None, knownfailure=None,
+                      vectorized=True, dtype=None):
+    if hasattr(points, 'next'):
+        # it's a generator
+        points = list(points)
+
+    points = np.asarray(points)
+    if points.ndim == 1:
+        points = points[:,None]
+
+    if hasattr(results, '__name__'):
+        # function
+        if vectorized:
+            results = results(*tuple(points.T))
+        else:
+            results = np.array([results(*tuple(p)) for p in points])
+            if results.dtype == object:
+                try:
+                    results = results.astype(float)
+                except TypeError:
+                    results = results.astype(complex)
+    else:
+        results = np.asarray(results)
+
+    npoints = points.shape[1]
+
+    data = np.c_[points, results]
+    fdata = FuncData(func, data, range(npoints), range(npoints, data.shape[1]),
+                     rtol=rtol, atol=atol, param_filter=param_filter,
+                     knownfailure=knownfailure)
+    fdata.check()
+
 class FuncData(object):
     """
     Data set for checking a special function.



More information about the Scipy-svn mailing list