[Scipy-svn] r3226 - in trunk/Lib/signal: . tests

scipy-svn@scip... scipy-svn@scip...
Fri Aug 10 19:08:27 CDT 2007


Author: stefan
Date: 2007-08-10 19:08:07 -0500 (Fri, 10 Aug 2007)
New Revision: 3226

Added:
   trunk/Lib/signal/tests/test_wavelets.py
Modified:
   trunk/Lib/signal/wavelets.py
Log:
Fix wavelet module.  Add tests.


Added: trunk/Lib/signal/tests/test_wavelets.py
===================================================================
--- trunk/Lib/signal/tests/test_wavelets.py	2007-08-08 15:15:24 UTC (rev 3225)
+++ trunk/Lib/signal/tests/test_wavelets.py	2007-08-11 00:08:07 UTC (rev 3226)
@@ -0,0 +1,26 @@
+import numpy as N  # NOTE(review): N is never used in this test module
+from numpy.testing import *  # presumably supplies NumpyTestCase, NumpyTest, assert_* and set_package_path/restore_path
+
+set_package_path()  # presumably points sys.path at the in-tree build so the import below tests it
+from scipy.signal import wavelets
+restore_path()  # undo the set_package_path() change
+
+class test_wavelets(NumpyTestCase):  # NOTE(review): check_* method names are presumably NumpyTestCase's discovery convention
+    def check_qmf(self):
+        assert_array_equal(wavelets.qmf([1,1]),[1,-1])  # quadrature mirror of [1,1] is expected to be [1,-1]
+
+    def check_daub(self):
+        for i in xrange(1,15):  # NOTE(review): xrange is Python 2 only
+            assert_equal(len(wavelets.daub(i)),i*2)  # an order-i filter is expected to have 2*i taps
+
+    def check_cascade(self):
+        for J in xrange(1,7):  # grid resolution: outputs are sampled at K/2**J
+            for i in xrange(1,5):  # Daubechies filter order
+                lpcoef = wavelets.daub(i)
+                k = len(lpcoef)
+                x,phi,psi = wavelets.cascade(lpcoef,J)
+                assert len(x) == len(phi) == len(psi)  # phi and psi are evaluated on the grid x
+                assert_equal(len(x),(k-1)*2**J)  # N*2**J points with N = len(lpcoef)-1, per cascade's docstring
+
+if __name__ == "__main__":
+    NumpyTest().run()

Modified: trunk/Lib/signal/wavelets.py
===================================================================
--- trunk/Lib/signal/wavelets.py	2007-08-08 15:15:24 UTC (rev 3225)
+++ trunk/Lib/signal/wavelets.py	2007-08-11 00:08:07 UTC (rev 3226)
@@ -1,3 +1,4 @@
+__all__ = ['daub','qmf','cascade']
 
 import numpy as sb
 from numpy.dual import eig
@@ -3,5 +4,4 @@
 from scipy.misc import comb
 
-
 def daub(p):
     """The coefficients for the FIR low-pass filter producing Daubechies wavelets.
@@ -47,7 +47,8 @@
             if (abs(z1)) < 1:
                 z1 = const - part
             q = q * [1,-z1]
-        q = sb.real(q) * c
+
+        q = c * sb.real(q)
         # Normalize result
         q = q / sb.sum(q) * sqrt(2)
         return q.c[::-1]
@@ -74,7 +75,7 @@
       J   -- values will be computed at grid points $K/2^J$
 
     Outputs:
-      x   -- the dyadic points $K/2^J$ for $K=0...N*2^J-1$
+      x   -- the dyadic points $K/2^J$ for $K=0...N*(2^J)-1$
               where len(hk)=len(gk)=N+1
       phi -- the scaling function phi(x) at x
                $\phi(x) = \sum_{k=0}^{N} h_k \phi(2x-k)$
@@ -118,7 +119,7 @@
     m *= s2
 
     # construct the grid of points
-    x = sb.arange(0,N*(1<<J),dtype=sb.Float) / (1<<J)
+    x = sb.arange(0,N*(1<<J),dtype=sb.float) / (1<<J)
     phi = 0*x
 
     psi = 0*x



More information about the Scipy-svn mailing list