[Scipy-svn] r4489 - in trunk/scipy/interpolate: . tests

scipy-svn@scip... scipy-svn@scip...
Sat Jun 28 08:55:25 CDT 2008


Author: ptvirtan
Date: 2008-06-28 08:54:11 -0500 (Sat, 28 Jun 2008)
New Revision: 4489

Modified:
   trunk/scipy/interpolate/interpolate.py
   trunk/scipy/interpolate/tests/test_interpolate.py
Log:
interpolate: Fix #289. Make interp1d order axes in the result correctly when y.ndim > 2. Fix a bug in splmake that was triggered when y.ndim > 2. Add corresponding tests.

Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py	2008-06-28 01:34:29 UTC (rev 4488)
+++ trunk/scipy/interpolate/interpolate.py	2008-06-28 13:54:11 UTC (rev 4489)
@@ -152,9 +152,6 @@
     UnivariateSpline - a more recent wrapper of the FITPACK routines
     """
 
-    _interp_axis = -1 # used to set which is default interpolation
-                      # axis.  DO NOT CHANGE OR CODE WILL BREAK.
-
     def __init__(self, x, y, kind='linear', axis=-1,
                  copy=True, bounds_error=True, fill_value=np.nan):
         """ Initialize a 1D linear interpolation class.
@@ -226,12 +223,18 @@
         if kind == 'linear':
             # Make a "view" of the y array that is rotated to the interpolation
             # axis.
-            oriented_y = y.swapaxes(self._interp_axis, axis)
+            axes = range(y.ndim)
+            del axes[self.axis]
+            axes.append(self.axis)
+            oriented_y = y.transpose(axes)
             minval = 2
-            len_y = oriented_y.shape[self._interp_axis]
+            len_y = oriented_y.shape[-1]
             self._call = self._call_linear
         else:
-            oriented_y = y.swapaxes(0, axis)
+            axes = range(y.ndim)
+            del axes[self.axis]
+            axes.insert(0, self.axis)
+            oriented_y = y.transpose(axes)
             minval = order + 1
             len_y = oriented_y.shape[0]
             self._call = self._call_spline
@@ -322,10 +325,10 @@
             return y_new.transpose(axes)
         else:
             y_new[out_of_bounds] = self.fill_value
-            axes = range(ny - nx, ny)
-            axes[self.axis:self.axis] = range(ny - nx)
+            axes = range(nx, ny)
+            axes[self.axis:self.axis] = range(nx)
             return y_new.transpose(axes)
-
+    
     def _check_bounds(self, x_new):
         """ Check the inputs for being in the bounds of the interpolated data.
 
@@ -407,6 +410,16 @@
     fromspline = classmethod(fromspline)
 
 
+def _dot0(a, b):
+    """Similar to numpy.dot, but sum over last axis of a and 1st axis of b"""
+    if b.ndim <= 2:
+        return dot(a, b)
+    else:
+        axes = range(b.ndim)
+        axes.insert(-1, 0)
+        axes.pop(0)
+        return dot(a, b.transpose(axes))
+
 def _find_smoothest(xk, yk, order, conds=None, B=None):
     # construct Bmatrix, and Jmatrix
     # e = J*c
@@ -431,9 +444,8 @@
     tmp = dot(tmp,V1)
     tmp = dot(tmp,np.diag(1.0/s))
     tmp = dot(tmp,u.T)
-    return dot(tmp, yk)
+    return _dot0(tmp, yk)
 
-
 def _setdiag(a, k, v):
     assert (a.ndim==2)
     M,N = a.shape
@@ -471,7 +483,7 @@
     V2[1::2] = -1
     V2 /= math.sqrt(Np1)
     dk = np.diff(xk)
-    b = 2*np.diff(yk)/dk
+    b = 2*np.diff(yk, axis=0)/dk
     J = np.zeros((N-1,N+1))
     idk = 1.0/dk
     _setdiag(J,0,idk[:-1])
@@ -480,7 +492,7 @@
     A = dot(J.T,J)
     val = dot(V2,dot(A,V2))
     res1 = dot(np.outer(V2,V2)/val,A)
-    mk = dot(np.eye(Np1)-res1,dot(Bd,b))
+    mk = dot(np.eye(Np1)-res1, _dot0(Bd,b))
     return mk
 
 def _get_spline2_Bb(xk, yk, kind, conds):

Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py	2008-06-28 01:34:29 UTC (rev 4488)
+++ trunk/scipy/interpolate/tests/test_interpolate.py	2008-06-28 13:54:11 UTC (rev 4489)
@@ -29,7 +29,7 @@
 
         self.y210 = np.arange(20.).reshape((2, 10))
         self.y102 = np.arange(20.).reshape((10, 2))
-
+        
         self.fill_value = -100.0
 
     def test_validation(self):
@@ -125,13 +125,30 @@
             np.array([2.4, 5.6, 6.0]),
         )
 
+    def test_cubic(self):
+        """ Check the actual implementation of spline interpolation.
+        """
 
-    def test_bounds(self):
+        interp10 = interp1d(self.x10, self.y10, kind='cubic')
+        assert_array_almost_equal(
+            interp10(self.x10),
+            self.y10,
+        )
+        assert_array_almost_equal(
+            interp10(1.2),
+            np.array([1.2]),
+        )
+        assert_array_almost_equal(
+            interp10([2.4, 5.6, 6.0]),
+            np.array([2.4, 5.6, 6.0]),
+        )
+
+    def _bounds_check(self, kind='linear'):
         """ Test that our handling of out-of-bounds input is correct.
         """
 
         extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
-            bounds_error=False)
+            bounds_error=False, kind=kind)
         assert_array_equal(
             extrap10(11.2),
             np.array([self.fill_value]),
@@ -145,25 +162,28 @@
             np.array([True, False, False, False, True]),
         )
 
-        raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True)
+        raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True,
+                                       kind=kind)
         self.assertRaises(ValueError, raises_bounds_error, -1.0)
         self.assertRaises(ValueError, raises_bounds_error, 11.0)
         raises_bounds_error([0.0, 5.0, 9.0])
 
+    def test_bounds(self):
+        for kind in ('linear', 'cubic'):
+            self._bounds_check(kind=kind)
 
-    def test_nd(self):
+    def _nd_check(self, kind='linear'):
         """ Check the behavior when the inputs and outputs are multidimensional.
         """
-
         # Multidimensional input.
-        interp10 = interp1d(self.x10, self.y10)
+        interp10 = interp1d(self.x10, self.y10, kind=kind)
         assert_array_almost_equal(
             interp10(np.array([[3.4, 5.6], [2.4, 7.8]])),
             np.array([[3.4, 5.6], [2.4, 7.8]]),
         )
-
+        
         # Multidimensional outputs.
-        interp210 = interp1d(self.x10, self.y210)
+        interp210 = interp1d(self.x10, self.y210, kind=kind)
         assert_array_almost_equal(
             interp210(1.5),
             np.array([[1.5], [11.5]]),
@@ -174,7 +194,7 @@
                       [11.5, 12.4]]),
         )
 
-        interp102 = interp1d(self.x10, self.y102, axis=0)
+        interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
         assert_array_almost_equal(
             interp102(1.5),
             np.array([[3.0, 4.0]]),
@@ -197,7 +217,24 @@
             np.array([[[6.8, 7.8], [11.2, 12.2]],
                       [[4.8, 5.8], [15.6, 16.6]]]),
         )
+        
+        # Check large ndim output
+        a = [4, 5, 6, 7]
+        y = np.arange(np.prod(a)).reshape(*a)
+        for n, s in enumerate(a):
+            x = np.arange(s)
+            z = interp1d(x, y, axis=n, kind=kind)
+            assert_array_almost_equal(z(x), y)
+            
+            x2 = np.arange(2*3*1).reshape((2,3,1)) / 12.
+            b = list(a)
+            b[n:n+1] = [2,3,1]
+            assert_array_almost_equal(z(x2).shape, b)
 
+    def test_nd(self):
+        for kind in ('linear', 'cubic'):
+            self._nd_check(kind=kind)
+
 class TestLagrange(TestCase):
 
     def test_lagrange(self):



More information about the Scipy-svn mailing list