[Scipy-svn] r6982 - trunk/scipy/signal/tests

scipy-svn@scip... scipy-svn@scip...
Mon Nov 29 17:41:08 CST 2010


Author: warren.weckesser
Date: 2010-11-29 17:41:08 -0600 (Mon, 29 Nov 2010)
New Revision: 6982

Modified:
   trunk/scipy/signal/tests/test_signaltools.py
Log:
TST: signal: make the decimal precision of the complex tests of the correlate function depend on the data dtype.

Modified: trunk/scipy/signal/tests/test_signaltools.py
===================================================================
--- trunk/scipy/signal/tests/test_signaltools.py	2010-11-29 14:57:55 UTC (rev 6981)
+++ trunk/scipy/signal/tests/test_signaltools.py	2010-11-29 23:41:08 UTC (rev 6982)
@@ -1,4 +1,4 @@
-#this program corresponds to special.py
+
 from decimal import Decimal
 
 from numpy.testing import TestCase, run_module_suite, assert_equal, \
@@ -12,6 +12,7 @@
 from numpy import array, arange
 import numpy as np
 
+
 class _TestConvolve(TestCase):
     def test_basic(self):
         a = [3,4,5,6,5,4]
@@ -293,7 +294,10 @@
 class TestWiener(TestCase):
     def test_basic(self):
         g = array([[5,6,4,3],[3,5,6,2],[2,3,5,6],[1,6,9,7]],'d')
-        correct = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],[2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],[2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],[1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
+        correct = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],
+                         [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
+                         [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
+                         [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
         h = signal.wiener(g)
         assert_array_almost_equal(h,correct,decimal=6)
 
@@ -449,8 +453,11 @@
 class TestLinearFilterDecimal(_TestLinearFilter):
     dt = np.dtype(Decimal)
 
+
 class _TestCorrelateReal(TestCase):
+
     dt = None
+
     def _setup_rank1(self):
         # a.size should be greated than b.size for the tests
         a = np.linspace(0, 3, 4).astype(self.dt)
@@ -568,6 +575,7 @@
         assert_array_almost_equal(y, y_r)
         self.assertTrue(y.dtype == self.dt)
 
+
 def _get_testcorrelate_class(i, base):
     class TestCorrelateX(base):
         dt = i
@@ -580,9 +588,19 @@
     cls = _get_testcorrelate_class(i, _TestCorrelateReal)
     globals()[cls.__name__] = cls
 
+
 class _TestCorrelateComplex(TestCase):
+
+    # The numpy data type to use.
     dt = None
+    
+    # The decimal precision to be used for comparing results.
+    # This value will be passed as the 'decimal' keyword argument of
+    # assert_array_almost_equal().
+    decimal = None
+
     def _setup_rank1(self, mode):
+        np.random.seed(9)
         a = np.random.randn(10).astype(self.dt)
         a += 1j * np.random.randn(10).astype(self.dt)
         b = np.random.randn(8).astype(self.dt)
@@ -597,19 +615,19 @@
     def test_rank1_valid(self):
         a, b, y_r = self._setup_rank1('valid')
         y = correlate(a, b, 'valid', old_behavior=False)
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     def test_rank1_same(self):
         a, b, y_r = self._setup_rank1('same')
         y = correlate(a, b, 'same', old_behavior=False)
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     def test_rank1_full(self):
         a, b, y_r = self._setup_rank1('full')
         y = correlate(a, b, 'full', old_behavior=False)
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     def test_rank3(self):
@@ -624,28 +642,28 @@
                 correlate(a.imag, b.real, old_behavior=False))
 
         y = correlate(a, b, 'full', old_behavior=False)
-        assert_array_almost_equal(y, y_r, decimal=4)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal-1)
         self.assertTrue(y.dtype == self.dt)
 
     @dec.deprecated()
     def test_rank1_valid_old(self):
         a, b, y_r = self._setup_rank1('valid')
         y = correlate(b, a.conj(), 'valid')
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     @dec.deprecated()
     def test_rank1_same_old(self):
         a, b, y_r = self._setup_rank1('same')
         y = correlate(b, a.conj(), 'same')
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     @dec.deprecated()
     def test_rank1_full_old(self):
         a, b, y_r = self._setup_rank1('full')
         y = correlate(b, a.conj(), 'full')
-        assert_array_almost_equal(y, y_r)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal)
         self.assertTrue(y.dtype == self.dt)
 
     @dec.deprecated()
@@ -661,13 +679,20 @@
                 correlate(a.imag, b.real, old_behavior=False))
 
         y = correlate(b, a.conj(), 'full')
-        assert_array_almost_equal(y, y_r, decimal=4)
+        assert_array_almost_equal(y, y_r, decimal=self.decimal-1)
         self.assertTrue(y.dtype == self.dt)
 
-for i in [np.csingle, np.cdouble, np.clongdouble]:
+
+# Create three classes, one for each complex data type: TestCorrelateComplex64,
+# TestCorrelateComplex128 and TestCorrelateComplex256.
+# The second number in the pairs is used in the 'decimal' keyword argument of
+# the  array comparisons in the tests.
+for i, decimal in [(np.csingle, 5), (np.cdouble, 10), (np.clongdouble, 15)]:
     cls = _get_testcorrelate_class(i, _TestCorrelateComplex)
+    cls.decimal = decimal
     globals()[cls.__name__] = cls
 
+
 class TestFiltFilt:
     def test_basic(self):
         out = signal.filtfilt([1,2,3], [1,2,3], np.arange(12))



More information about the Scipy-svn mailing list