[Scipy-svn] r5469 - trunk/scipy/fftpack/tests

scipy-svn@scip... scipy-svn@scip...
Fri Jan 16 01:34:43 CST 2009

Author: cdavid
Date: 2009-01-16 01:34:39 -0600 (Fri, 16 Jan 2009)
New Revision: 5469

Modified:
trunk/scipy/fftpack/tests/test_real_transforms.py
Log:
Add tests for the wrapper around fftpack's DCT-II.

Modified: trunk/scipy/fftpack/tests/test_real_transforms.py
===================================================================
--- trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-16 07:34:25 UTC (rev 5468)
+++ trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-16 07:34:39 UTC (rev 5469)
@@ -3,9 +3,10 @@

import numpy as np
from numpy.fft import fft as numfft
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal, TestCase

from scipy.io import loadmat
+from scipy.fftpack.realtransforms import dct1, dct2

TDATA = loadmat(join(dirname(__file__), 'test.mat'),
squeeze_me=True,  struct_as_record=True, mat_dtype=True)
@@ -50,10 +51,12 @@
Note that it is not 'normalized'
"""
n = x.size
-    a = np.empty((n, n), dtype = x.dtype)
-    for i in xrange(n):
-        for j in xrange(n):
-            a[i, j] = x[j] * np.cos(np.pi * (0.5 + j) * i / n)
+    #a = np.empty((n, n), dtype = x.dtype)
+    #for i in xrange(n):
+    #    for j in xrange(n):
+    #        a[i, j] = x[j] * np.cos(np.pi * (0.5 + j) * i / n)
+    grd = np.outer(np.linspace(0, n - 1, n),  np.linspace(0.5, 0.5 + n - 1, n))
+    a = np.cos(np.pi / n * grd) * x

return 2 * a.sum(axis = 1)

@@ -106,5 +109,43 @@
for i in range(len(X)):
assert_array_almost_equal(direct_dct2(X[i]), fdct2(X[i]))

+class _TestDCTIIBase(TestCase):
+    """Compare dct2 with the direct O(n^2) definition for one real dtype.
+
+    Subclasses set self.rdt (the dtype under test) and self.dec (the
+    decimal precision given to assert_array_almost_equal).  This base
+    is not meant to run on its own: rdt is None here.
+    """
+    def setUp(self):
+        self.rdt = None
+        self.dec = 6
+
+    def test_definition(self):
+        for i in range(len(X)):
+            x = np.array(X[i], dtype=self.rdt)
+            yr = direct_dct2(x)
+            y = dct2(x)
+            self.failUnless(y.dtype == self.rdt,
+                    "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
+            # Compare relative to the largest reference coefficient so one
+            # decimal threshold works for inputs of any magnitude; the
+            # float32 error grows with the size of the output.
+            scale = np.max(np.abs(yr)) or 1.0
+            assert_array_almost_equal(y / scale, yr / scale,
+                    decimal=self.dec, err_msg="Input %d failed" % i)
+
+class TestDCTIIDouble(_TestDCTIIBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 10
+
+class TestDCTIIFloat(_TestDCTIIBase):
+    def setUp(self):
+        # Single precision: dct2 must keep float32, and agreement with
+        # the float64 reference is only expected to ~5 decimals.
+        # NOTE(review): decimal levels are estimates; tune if needed.
+        self.rdt = np.float32
+        self.dec = 5

if __name__ == "__main__":
np.testing.run_module_suite()