[Scipy-svn] r5498 - trunk/scipy/fftpack/tests

scipy-svn@scip... scipy-svn@scip...
Mon Jan 19 02:32:55 CST 2009


Author: cdavid
Date: 2009-01-19 02:32:50 -0600 (Mon, 19 Jan 2009)
New Revision: 5498

Modified:
   trunk/scipy/fftpack/tests/test_real_transforms.py
Log:
Refactor DCT tests so that the core code is independent of the DCT type.

Modified: trunk/scipy/fftpack/tests/test_real_transforms.py
===================================================================
--- trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-19 08:32:32 UTC (rev 5497)
+++ trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-19 08:32:50 UTC (rev 5498)
@@ -21,6 +21,8 @@
 FFTWDATA_SINGLE = np.load(join(dirname(__file__), 'fftw_single_ref.npz'))
 FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
 
+TYPE2DCT = {1: dct1, 2: dct2, 3: dct3}
+
 def fftw_ref(type, size, dt):
     x = np.linspace(0, size-1, size).astype(dt)
     if dt == np.double:
@@ -32,15 +34,17 @@
     y = (data['dct_%d_%d' % (type, size)]).astype(dt)
     return x, y
 
-class _TestDCTIIBase(TestCase):
+class _TestDCTBase(TestCase):
     def setUp(self):
         self.rdt = None
         self.dec = 14
+        self.type = None
+        self.func = None
 
     def test_definition(self):
         for i in FFTWDATA_SIZES:
-            x, yr = fftw_ref(2, i, self.rdt)
-            y = dct2(x)
+            x, yr = fftw_ref(self.type, i, self.rdt)
+            y = self.func(x)
             self.failUnless(y.dtype == self.rdt,
                     "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
             # XXX: we divide by np.max(y) because the tests fail otherwise. We
@@ -50,53 +54,31 @@
             assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec, 
                     err_msg="Size %d failed" % i)
 
-    def test_definition_ortho(self):
-        """Test orthornomal mode."""
-        for i in range(len(X)):
-            x = np.array(X[i], dtype=self.rdt)
-            yr = Y[i]
-            y = dct2(x, norm="ortho")
-            self.failUnless(y.dtype == self.rdt,
-                    "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
-            assert_array_almost_equal(y, yr, decimal=self.dec)
-
     def test_axis(self):
         nt = 2
         for i in [7, 8, 9, 16, 32, 64]:
             x = np.random.randn(nt, i)
-            y = dct2(x)
+            y = self.func(x)
             for j in range(nt):
-                assert_array_almost_equal(y[j], dct2(x[j]), decimal=self.dec)
+                assert_array_almost_equal(y[j], self.func(x[j]), decimal=self.dec)
 
             x = x.T
-            y = dct2(x, axis=0)
+            y = self.func(x, axis=0)
             for j in range(nt):
-                assert_array_almost_equal(y[:,j], dct2(x[:,j]), decimal=self.dec)
+                assert_array_almost_equal(y[:,j], self.func(x[:,j]), decimal=self.dec)
 
-class TestDCTIIDouble(_TestDCTIIBase):
-    def setUp(self):
-        self.rdt = np.double
-        self.dec = 10
-
-class TestDCTIIFloat(_TestDCTIIBase):
-    def setUp(self):
-        self.rdt = np.float32
-        self.dec = 5
-
-class _TestDCTIIIBase(TestCase):
-    def setUp(self):
-        self.rdt = None
-        self.dec = 14
-
-    def test_definition(self):
+class _TestDCTIIBase(_TestDCTBase):
+    def test_definition_matlab(self):
+        """Test correspondance with matlab (orthornomal mode)."""
         for i in range(len(X)):
             x = np.array(X[i], dtype=self.rdt)
-            y = dct3(x)
+            yr = Y[i]
+            y = dct2(x, norm="ortho")
             self.failUnless(y.dtype == self.rdt,
                     "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
-            assert_array_almost_equal(dct2(y) / (2*x.size), x,
-                    decimal=self.dec)
+            assert_array_almost_equal(y, yr, decimal=self.dec)
 
+class _TestDCTIIIBase(_TestDCTBase):
     def test_definition_ortho(self):
         """Test orthornomal mode."""
         for i in range(len(X)):
@@ -107,29 +89,47 @@
                     "Output dtype is %s, expected %s" % (xi.dtype, self.rdt))
             assert_array_almost_equal(xi, x, decimal=self.dec)
 
-    def test_axis(self):
-        nt = 2
-        for i in [7, 8, 9, 16, 32, 64]:
-            x = np.random.randn(nt, i)
-            y = dct3(x)
-            for j in range(nt):
-                assert_array_almost_equal(y[j], dct3(x[j]), decimal=self.dec)
+class TestDCTIDouble(_TestDCTBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 10
+        self.type = 1
+        self.func = TYPE2DCT[self.type]
 
-            x = x.T
-            y = dct3(x, axis=0)
-            for j in range(nt):
-                assert_array_almost_equal(y[:,j], dct3(x[:,j]),
-                        decimal=self.dec)
+class TestDCTIFloat(_TestDCTBase):
+    def setUp(self):
+        self.rdt = np.float32
+        self.dec = 5
+        self.type = 1
+        self.func = TYPE2DCT[self.type]
 
+class TestDCTIIDouble(_TestDCTIIBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 10
+        self.type = 2
+        self.func = TYPE2DCT[self.type]
+
+class TestDCTIIFloat(_TestDCTIIBase):
+    def setUp(self):
+        self.rdt = np.float32
+        self.dec = 5
+        self.type = 2
+        self.func = TYPE2DCT[self.type]
+
 class TestDCTIIIDouble(_TestDCTIIIBase):
     def setUp(self):
         self.rdt = np.double
         self.dec = 14
+        self.type = 3
+        self.func = TYPE2DCT[self.type]
 
 class TestDCTIIIFloat(_TestDCTIIIBase):
     def setUp(self):
         self.rdt = np.float32
         self.dec = 5
+        self.type = 3
+        self.func = TYPE2DCT[self.type]
 
 if __name__ == "__main__":
     np.testing.run_module_suite()



More information about the Scipy-svn mailing list