[Scipy-svn] r5545 - branches/special_refactor/scipy/special/core/tests

scipy-svn@scip... scipy-svn@scip...
Thu Feb 12 16:10:23 CST 2009


Author: ptvirtan
Date: 2009-02-12 16:09:45 -0600 (Thu, 12 Feb 2009)
New Revision: 5545

Modified:
   branches/special_refactor/scipy/special/core/tests/test_core.py
Log:
Also test the complex-valued variants of the special functions

Modified: branches/special_refactor/scipy/special/core/tests/test_core.py
===================================================================
--- branches/special_refactor/scipy/special/core/tests/test_core.py	2009-02-12 03:38:50 UTC (rev 5544)
+++ branches/special_refactor/scipy/special/core/tests/test_core.py	2009-02-12 22:09:45 UTC (rev 5545)
@@ -27,29 +27,41 @@
 
     TESTS = [
         Data(arccosh, 'acosh_data.txt', 0, 1),
+        Data(arccosh, 'acosh_data.txt', 0j, 1, rtol=5e-14),
+        
         Data(arcsinh, 'asinh_data.txt', 0, 1),
+        Data(arcsinh, 'asinh_data.txt', 0j, 1),
+        
         Data(arctanh, 'atanh_data.txt', 0, 1),
+        Data(arctanh, 'atanh_data.txt', 0j, 1),
 
         Data(beta, 'beta_exp_data.txt', (0,1), 2),
-        Data(beta, 'beta_med_data.txt', (0,1), 2, rtol=1e-12),
+        Data(beta, 'beta_exp_data.txt', (0,1), 2),
         Data(beta, 'beta_small_data.txt', (0,1), 2),
 
         Data(cbrt, 'cbrt_data.txt', 1, 0),
 
         Data(digamma, 'digamma_data.txt', 0, 1),
+        Data(digamma, 'digamma_data.txt', 0j, 1),
         Data(digamma, 'digamma_neg_data.txt', 0, 1, rtol=1e-13),
+        Data(digamma, 'digamma_neg_data.txt', 0j, 1, rtol=1e-13),
         Data(digamma, 'digamma_root_data.txt', 0, 1, rtol=1e-12),
+        Data(digamma, 'digamma_root_data.txt', 0j, 1, rtol=1e-12),
         Data(digamma, 'digamma_small_data.txt', 0, 1),
+        Data(digamma, 'digamma_small_data.txt', 0j, 1),
 
         Data(ellipk_, 'ellint_k_data.txt', 0, 1),
         Data(ellipe_, 'ellint_e_data.txt', 0, 1),
         Data(ellipeinc_, 'ellint_e2_data.txt', (0,1), 2),
 
         Data(erf, 'erf_data.txt', 0, 1),
+        Data(erf, 'erf_data.txt', 0j, 1, rtol=1e-14),
         Data(erfc, 'erf_data.txt', 0, 2),
         Data(erf, 'erf_large_data.txt', 0, 1),
+        Data(erf, 'erf_large_data.txt', 0j, 1),
         Data(erfc, 'erf_large_data.txt', 0, 2),
         Data(erf, 'erf_small_data.txt', 0, 1),
+        Data(erf, 'erf_small_data.txt', 0j, 1),
         Data(erfc, 'erf_small_data.txt', 0, 2),
 
         Data(erfinv, 'erf_inv_data.txt', 0, 1),
@@ -57,42 +69,58 @@
         #Data(erfcinv, 'erfc_inv_big_data.txt', 0, 1),
 
         Data(exp1, 'expint_1_data.txt', 1, 2),
-        Data(expi, 'expinti_data.txt', 0, 1),
-        Data(expi, 'expinti_data_double.txt', 0, 1),
+        Data(exp1, 'expint_1_data.txt', 1j, 2, rtol=2e-9),
+        Data(expi, 'expinti_data.txt', 0, 1, param_filter=(lambda x: x>0)),
+        Data(expi, 'expinti_data_double.txt', 0, 1, param_filter=(lambda x: x>0)),
 
         Data(expn, 'expint_small_data.txt', (0,1), 2),
         Data(expn, 'expint_data.txt', (0,1), 2),
 
         Data(gamma, 'gamma_data.txt', 0, 1),
+        Data(gamma, 'gamma_data.txt', 0j, 1, rtol=2e-9),
         Data(gammaln, 'gamma_data.txt', 0, 2, rtol=5e-11),
         
         Data(log1p, 'log1p_expm1_data.txt', 0, 1),
         Data(expm1, 'log1p_expm1_data.txt', 0, 2),
 
         Data(iv, 'bessel_i_data.txt', (0,1), 2, rtol=1e-12),
-        Data(iv, 'bessel_i_int_data.txt', (0,1), 2, rtol=1e-12),
+        Data(iv, 'bessel_i_data.txt', (0,1j), 2, rtol=2e-10),
+        Data(iv, 'bessel_i_int_data.txt', (0,1), 2, rtol=1e-9),
+        Data(iv, 'bessel_i_int_data.txt', (0,1j), 2, rtol=2e-10),
 
         Data(jn, 'bessel_j_int_data.txt', (0,1), 2, rtol=1e-12),
+        Data(jn, 'bessel_j_int_data.txt', (0,1j), 2, rtol=1e-12),
         Data(jn, 'bessel_j_large_data.txt', (0,1), 2, rtol=6e-11),
+        Data(jn, 'bessel_j_large_data.txt', (0,1j), 2, rtol=6e-11),
+        
+        Data(jv, 'bessel_j_int_data.txt', (0,1), 2, rtol=1e-12),
+        Data(jv, 'bessel_j_int_data.txt', (0,1j), 2, rtol=1e-12),
         Data(jv, 'bessel_j_data.txt', (0,1), 2, rtol=1e-12),
+        Data(jv, 'bessel_j_data.txt', (0,1j), 2, rtol=1e-12),
 
         Data(kn, 'bessel_k_int_data.txt', (0,1), 2, rtol=1e-12),
+        Data(kn, 'bessel_k_int_data.txt', (0,1), 2, rtol=1e-12),
+
+        Data(kv, 'bessel_k_int_data.txt', (0,1), 2, rtol=1e-12),
+        Data(kv, 'bessel_k_int_data.txt', (0,1j), 2, rtol=1e-12),
         Data(kv, 'bessel_k_data.txt', (0,1), 2, rtol=1e-12),
+        Data(kv, 'bessel_k_data.txt', (0,1j), 2, rtol=1e-12),
 
         Data(yn, 'bessel_y01_data.txt', (0,1), 2, rtol=1e-12),
         Data(yn, 'bessel_yn_data.txt', (0,1), 2, rtol=1e-12),
+
+        Data(yv, 'bessel_yn_data.txt', (0,1), 2, rtol=1e-12),
+        Data(yv, 'bessel_yn_data.txt', (0,1j), 2, rtol=1e-12),
         Data(yv, 'bessel_yv_data.txt', (0,1), 2, rtol=1e-12),
+        Data(yv, 'bessel_yv_data.txt', (0,1j), 2, rtol=1e-10),
 
-        Data(zeta_, 'zeta_data.txt', 0, 1),
-        Data(zeta_, 'zeta_neg_data.txt', 0, 1),
-        Data(zeta_, 'zeta_1_up_data.txt', 0, 1),
-        Data(zeta_, 'zeta_1_below_data.txt', 0, 1),
+        Data(zeta_, 'zeta_data.txt', 0, 1, param_filter=(lambda s: s > 1)),
+        Data(zeta_, 'zeta_neg_data.txt', 0, 1, param_filter=(lambda s: s > 1)),
+        Data(zeta_, 'zeta_1_up_data.txt', 0, 1, param_filter=(lambda s: s > 1)),
+        Data(zeta_, 'zeta_1_below_data.txt', 0, 1, param_filter=(lambda s: s > 1)),
 
         # -- not used yet:
         # assoc_legendre_p.txt
-        # beta_exp_data.txt
-        # beta_med_data.txt
-        # beta_small_data.txt
         # binomial_data.txt
         # binomial_large_data.txt
         # binomial_quantile_data.txt
@@ -152,9 +180,34 @@
     """Boost test"""
     test.check(dtype=dtype)
 
+
+#------------------------------------------------------------------------------
+
 class Data(object):
+    """
+    Data set for checking a special function.
+
+    Parameters
+    ----------
+    func : function
+    filename : str
+    param_columns : int or tuple of ints
+        Columns indices in which the parameters to `func` lie.
+        Can be imaginary integers to indicate that the parameter
+        should be cast to complex.
+    result_columns : int or tuple of ints
+        Column indices for expected results from `func`.
+    rtol : float
+        Required relative tolerance
+    atol : float
+        Required absolute tolerance
+    param_filter : function, or tuple of functions/Nones
+        Filter functions to exclude some parameter ranges.
+
+    """
+    
     def __init__(self, func, filename, param_columns, result_columns,
-                 rtol=None, atol=None):
+                 rtol=None, atol=None, param_filter=None):
         self.func = func
         self.filename = os.path.join(DATA_DIR, filename)
         if not hasattr(param_columns, '__len__'):
@@ -165,6 +218,9 @@
         self.result_columns = tuple(result_columns)
         self.rtol = rtol
         self.atol = atol
+        if not hasattr(param_filter, '__len__'):
+            param_filter = (param_filter,)
+        self.param_filter = param_filter
 
     def get_tolerances(self, dtype):
         info = np.finfo(dtype)
@@ -177,6 +233,7 @@
 
     @staticmethod
     def load_data(filename, dtype):
+        """Load table data from a file; similar to np.loadtxt, but faster"""
         f = open(filename, 'r')
         try:
             ncols = 1
@@ -193,18 +250,38 @@
         return data
 
     def check(self, data=None, dtype=np.double):
+        """Check the special function against the data."""
+
         if data is None:
             data = Data.load_data(self.filename, dtype)
 
         rtol, atol = self.get_tolerances(dtype)
 
-        params = tuple([data[:,j] for j in self.param_columns])
+        # Apply given filter functions
+        if self.param_filter:
+            param_mask = np.ones((data.shape[0],), np.bool_)
+            for j, filter in zip(self.param_columns, self.param_filter):
+                if filter:
+                    param_mask &= filter(data[:,j])
+            data = data[param_mask]
+
+        # Pick parameters and results from the correct columns
+        params = []
+        for j in self.param_columns:
+            if np.iscomplexobj(j):
+                j = int(j.imag)
+                params.append(data[:,j].astype(np.complex))
+            else:
+                params.append(data[:,j])
         wanted = tuple([data[:,j] for j in self.result_columns])
+
+        # Evaluate
         got = self.func(*params)
-
         if not isinstance(got, tuple):
             got = (got,)
 
+        # Check the validity of each output returned
+
         assert len(got) == len(wanted)
 
         for output_num, (x, y) in enumerate(zip(got, wanted)):
@@ -231,6 +308,7 @@
             bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)
 
             if np.any(bad_j):
+                # Some bad results: inform what, where, and how bad
                 msg = [""]
                 msg.append("Max |adiff|: %g" % diff.max())
                 msg.append("Max |rdiff|: %g" % rdiff.max())
@@ -242,9 +320,15 @@
                     a = "  ".join(map(fmt, params))
                     b = "  ".join(map(fmt, got))
                     c = "  ".join(map(fmt, wanted))
-                    msg.append("%s => %s != %s" % (a, b, c))
+                    d = fmt(rdiff)
+                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
                 assert False, "\n".join(msg)
 
     def __repr__(self):
-        return "<Boost test for %s: %s>" % (self.func.__name__,
-                                            os.path.basename(self.filename))
+        """Pretty-printing, esp. for Nose output"""
+        if np.any(map(np.iscomplexobj, self.param_columns)):
+            is_complex = " (complex)"
+        else:
+            is_complex = ""
+        return "<Boost test for %s%s: %s>" % (self.func.__name__, is_complex,
+                                              os.path.basename(self.filename))



More information about the Scipy-svn mailing list