[Scipy-svn] r6293 - in trunk/scipy/fftpack: . tests

scipy-svn@scip... scipy-svn@scip...
Wed Mar 31 01:49:33 CDT 2010


Author: cdavid
Date: 2010-03-31 01:49:33 -0500 (Wed, 31 Mar 2010)
New Revision: 6293

Modified:
   trunk/scipy/fftpack/basic.py
   trunk/scipy/fftpack/tests/test_basic.py
Log:
BUG: fix long double fft (fix #948)

Modified: trunk/scipy/fftpack/basic.py
===================================================================
--- trunk/scipy/fftpack/basic.py	2010-03-31 05:33:56 UTC (rev 6292)
+++ trunk/scipy/fftpack/basic.py	2010-03-31 06:49:33 UTC (rev 6293)
@@ -8,22 +8,55 @@
 __all__ = ['fft','ifft','fftn','ifftn','rfft','irfft',
            'fft2','ifft2', 'rfftfreq']
 
-from numpy import asarray, zeros, swapaxes, integer, array
+from numpy import zeros, swapaxes, integer, array
 import numpy
-import _fftpack as fftpack
+import _fftpack
 
 import atexit
-atexit.register(fftpack.destroy_zfft_cache)
-atexit.register(fftpack.destroy_zfftnd_cache)
-atexit.register(fftpack.destroy_drfft_cache)
-atexit.register(fftpack.destroy_cfft_cache)
-atexit.register(fftpack.destroy_cfftnd_cache)
-atexit.register(fftpack.destroy_rfft_cache)
+atexit.register(_fftpack.destroy_zfft_cache)
+atexit.register(_fftpack.destroy_zfftnd_cache)
+atexit.register(_fftpack.destroy_drfft_cache)
+atexit.register(_fftpack.destroy_cfft_cache)
+atexit.register(_fftpack.destroy_cfftnd_cache)
+atexit.register(_fftpack.destroy_rfft_cache)
 del atexit
 
 def istype(arr, typeclass):
+    """Return True if arr.dtype.type is typeclass or a subclass of it."""
     return issubclass(arr.dtype.type, typeclass)
 
+_DTYPE_TO_FFT = {
+        numpy.dtype(numpy.float32): _fftpack.crfft,
+        numpy.dtype(numpy.float64): _fftpack.zrfft,
+        numpy.dtype(numpy.complex64): _fftpack.cfft,
+        numpy.dtype(numpy.complex128): _fftpack.zfft,
+}
+
+_DTYPE_TO_RFFT = {
+        numpy.dtype(numpy.float32): _fftpack.rfft,
+        numpy.dtype(numpy.float64): _fftpack.drfft,
+}
+
+_DTYPE_TO_FFTN = {
+        numpy.dtype(numpy.complex64): _fftpack.cfftnd,
+        numpy.dtype(numpy.complex128): _fftpack.zfftnd,
+        numpy.dtype(numpy.float32): _fftpack.cfftnd,
+        numpy.dtype(numpy.float64): _fftpack.zfftnd,
+}
+
+def _asfarray(x):
+    """Like numpy asfarray, except that it does not modify x dtype if x is
+    already an array with a float dtype, and do not cast complex types to
+    real."""
+    if hasattr(x, "dtype") and x.dtype.char in numpy.typecodes["AllFloat"]:
+        return x
+    else:
+        # We cannot use asfarray directly because it converts sequences of
+        # complex to sequence of real
+        ret = numpy.asarray(x)
+        if not ret.dtype.char in numpy.typecodes["AllFloat"]:
+            return numpy.asfarray(x)
+        return ret
+
 def _fix_shape(x, n, axis):
     """ Internal auxiliary function for _raw_fft, _raw_fftnd."""
     s = list(x.shape)
@@ -106,21 +139,21 @@
     True
 
     """
-    tmp = asarray(x)
+    tmp = _asfarray(x)
+
+    try:
+        work_function = _DTYPE_TO_FFT[tmp.dtype]
+    except KeyError:
+        raise ValueError("type %s is not supported" % tmp.dtype)
+
     if istype(tmp, numpy.complex128):
         overwrite_x = overwrite_x or (tmp is not x and not \
                                       hasattr(x,'__array__'))
-        work_function = fftpack.zfft
     elif istype(tmp, numpy.complex64):
         overwrite_x = overwrite_x or (tmp is not x and not \
                                       hasattr(x,'__array__'))
-        work_function = fftpack.cfft
     else:
         overwrite_x = 1
-        if istype(tmp, numpy.float32):
-            work_function = fftpack.crfft
-        else:
-            work_function = fftpack.zrfft
 
     #return _raw_fft(tmp,n,axis,1,overwrite_x,work_function)
     if n is None:
@@ -149,21 +182,21 @@
 
     Optional input: see fft.__doc__
     """
-    tmp = asarray(x)
+    tmp = _asfarray(x)
+
+    try:
+        work_function = _DTYPE_TO_FFT[tmp.dtype]
+    except KeyError:
+        raise ValueError("type %s is not supported" % tmp.dtype)
+
     if istype(tmp, numpy.complex128):
         overwrite_x = overwrite_x or (tmp is not x and not \
                                       hasattr(x,'__array__'))
-        work_function = fftpack.zfft
     elif istype(tmp, numpy.complex64):
         overwrite_x = overwrite_x or (tmp is not x and not \
                                       hasattr(x,'__array__'))
-        work_function = fftpack.cfft
     else:
         overwrite_x = 1
-        if istype(tmp, numpy.float32):
-            work_function = fftpack.crfft
-        else:
-            work_function = fftpack.zrfft
 
     #return _raw_fft(tmp,n,axis,-1,overwrite_x,work_function)
     if n is None:
@@ -207,13 +240,16 @@
     Notes:
       y == rfft(irfft(y)) within numerical accuracy.
     """
-    tmp = asarray(x)
+    tmp = _asfarray(x)
+
     if not numpy.isrealobj(tmp):
         raise TypeError,"1st argument must be real sequence"
-    if istype(tmp, numpy.float32):
-        work_function = fftpack.rfft
-    else:
-        work_function = fftpack.drfft
+
+    try:
+        work_function = _DTYPE_TO_RFFT[tmp.dtype]
+    except KeyError:
+        raise ValueError("type %s is not supported" % tmp.dtype)
+
     return _raw_fft(tmp,n,axis,1,overwrite_x,work_function)
 
 
@@ -254,13 +290,15 @@
 
     Optional input: see rfft.__doc__
     """
-    tmp = asarray(x)
+    tmp = _asfarray(x)
     if not numpy.isrealobj(tmp):
         raise TypeError,"1st argument must be real sequence"
-    if istype(tmp, numpy.float32):
-        work_function = fftpack.rfft
-    else:
-        work_function = fftpack.drfft
+
+    try:
+        work_function = _DTYPE_TO_RFFT[tmp.dtype]
+    except KeyError:
+        raise ValueError("type %s is not supported" % tmp.dtype)
+
     return _raw_fft(tmp,n,axis,-1,overwrite_x,work_function)
 
 def _raw_fftnd(x, s, axes, direction, overwrite_x, work_function):
@@ -354,19 +392,20 @@
     return _raw_fftn_dispatch(x, shape, axes, overwrite_x, 1)
 
 def _raw_fftn_dispatch(x, shape, axes, overwrite_x, direction):
+    """Look up the n-D _fftpack routine for x's dtype and run _raw_fftnd.
+
+    Raises ValueError if the dtype left by _asfarray has no entry in
+    _DTYPE_TO_FFTN (e.g. long double).
+    """
-    tmp = asarray(x)
+    tmp = _asfarray(x)
+
+    try:
+        work_function = _DTYPE_TO_FFTN[tmp.dtype]
+    except KeyError:
+        raise ValueError("type %s is not supported" % tmp.dtype)
+
+    # complex128: overwrite_x is also enabled when tmp is a different object
+    # from x and x has no __array__ attribute. complex64: overwrite_x is used
+    # as given. NOTE(review): fft/ifft widen overwrite_x for complex64 as
+    # well; confirm the difference here is intended.
+    # float32/float64: overwrite_x is forced to 1 (presumably the routine
+    # works on a converted copy -- TODO confirm in _raw_fftnd).
     if istype(tmp, numpy.complex128):
         overwrite_x = overwrite_x or (tmp is not x and not \
                                       hasattr(x,'__array__'))
-        work_function = fftpack.zfftnd
     elif istype(tmp, numpy.complex64):
-        work_function = fftpack.cfftnd
+        pass
     else:
         overwrite_x = 1
-        if istype(tmp, numpy.float32):
-            work_function = fftpack.cfftnd
-        else:
-            work_function = fftpack.zfftnd
     return _raw_fftnd(tmp,shape,axes,direction,overwrite_x,work_function)
 
 

Modified: trunk/scipy/fftpack/tests/test_basic.py
===================================================================
--- trunk/scipy/fftpack/tests/test_basic.py	2010-03-31 05:33:56 UTC (rev 6292)
+++ trunk/scipy/fftpack/tests/test_basic.py	2010-03-31 06:49:33 UTC (rev 6293)
@@ -515,5 +515,29 @@
     cdtype = np.complex64
     maxnlp = 2000
 
+class TestLongDoubleFailure(TestCase):
+    def test_complex(self):
+        x = np.random.randn(10).astype(np.longdouble) + \
+                1j * np.random.randn(10).astype(np.longdouble)
+
+        for f in [fft, ifft]:
+            try:
+                f(x)
+                raise AssertionError("Type %r not supported but does not fail" % \
+                                     np.longcomplex)
+            except ValueError:
+                pass
+
+    def test_real(self):
+        x = np.random.randn(10).astype(np.longcomplex)
+
+        for f in [fft, ifft]:
+            try:
+                f(x)
+                raise AssertionError("Type %r not supported but does not fail" % \
+                                     np.longcomplex)
+            except ValueError:
+                pass
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Scipy-svn mailing list