# [Scipy-svn] r4489 - in trunk/scipy/interpolate: . tests

From: scipy-svn@scip... (address truncated by the list archive)
Sat Jun 28 08:55:25 CDT 2008

```
Author: ptvirtan
Date: 2008-06-28 08:54:11 -0500 (Sat, 28 Jun 2008)
New Revision: 4489

Modified:
trunk/scipy/interpolate/interpolate.py
trunk/scipy/interpolate/tests/test_interpolate.py
Log:
interpolate: Fix #289. Make interp1d order axes in the result correctly when y.ndim > 2. Fix a bug in splmake that was triggered when y.ndim > 2. Add corresponding tests.

Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py	2008-06-28 01:34:29 UTC (rev 4488)
+++ trunk/scipy/interpolate/interpolate.py	2008-06-28 13:54:11 UTC (rev 4489)
@@ -152,9 +152,6 @@
UnivariateSpline - a more recent wrapper of the FITPACK routines
"""

-    _interp_axis = -1 # used to set which is default interpolation
-                      # axis.  DO NOT CHANGE OR CODE WILL BREAK.
-
def __init__(self, x, y, kind='linear', axis=-1,
copy=True, bounds_error=True, fill_value=np.nan):
""" Initialize a 1D linear interpolation class.
@@ -226,12 +223,18 @@
if kind == 'linear':
# Make a "view" of the y array that is rotated to the interpolation
# axis.
-            oriented_y = y.swapaxes(self._interp_axis, axis)
+            axes = range(y.ndim)
+            del axes[self.axis]
+            axes.append(self.axis)
+            oriented_y = y.transpose(axes)
minval = 2
-            len_y = oriented_y.shape[self._interp_axis]
+            len_y = oriented_y.shape[-1]
self._call = self._call_linear
else:
-            oriented_y = y.swapaxes(0, axis)
+            axes = range(y.ndim)
+            del axes[self.axis]
+            axes.insert(0, self.axis)
+            oriented_y = y.transpose(axes)
minval = order + 1
len_y = oriented_y.shape[0]
self._call = self._call_spline
@@ -322,10 +325,10 @@
return y_new.transpose(axes)
else:
y_new[out_of_bounds] = self.fill_value
-            axes = range(ny - nx, ny)
-            axes[self.axis:self.axis] = range(ny - nx)
+            axes = range(nx, ny)
+            axes[self.axis:self.axis] = range(nx)
return y_new.transpose(axes)
-
+
def _check_bounds(self, x_new):
""" Check the inputs for being in the bounds of the interpolated data.

@@ -407,6 +410,16 @@
fromspline = classmethod(fromspline)

+def _dot0(a, b):
+    """Similar to numpy.dot, but sum over last axis of a and 1st axis of b"""
+    if b.ndim <= 2:
+        return dot(a, b)
+    else:
+        axes = range(b.ndim)
+        axes.insert(-1, 0)
+        axes.pop(0)
+        return dot(a, b.transpose(axes))
+
def _find_smoothest(xk, yk, order, conds=None, B=None):
# construct Bmatrix, and Jmatrix
# e = J*c
@@ -431,9 +444,8 @@
tmp = dot(tmp,V1)
tmp = dot(tmp,np.diag(1.0/s))
tmp = dot(tmp,u.T)
-    return dot(tmp, yk)
+    return _dot0(tmp, yk)

-
def _setdiag(a, k, v):
assert (a.ndim==2)
M,N = a.shape
@@ -471,7 +483,7 @@
V2[1::2] = -1
V2 /= math.sqrt(Np1)
dk = np.diff(xk)
-    b = 2*np.diff(yk)/dk
+    b = 2*np.diff(yk, axis=0)/dk
J = np.zeros((N-1,N+1))
idk = 1.0/dk
_setdiag(J,0,idk[:-1])
@@ -480,7 +492,7 @@
A = dot(J.T,J)
val = dot(V2,dot(A,V2))
res1 = dot(np.outer(V2,V2)/val,A)
-    mk = dot(np.eye(Np1)-res1,dot(Bd,b))
+    mk = dot(np.eye(Np1)-res1, _dot0(Bd,b))
return mk

def _get_spline2_Bb(xk, yk, kind, conds):

Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py	2008-06-28 01:34:29 UTC (rev 4488)
+++ trunk/scipy/interpolate/tests/test_interpolate.py	2008-06-28 13:54:11 UTC (rev 4489)
@@ -29,7 +29,7 @@

self.y210 = np.arange(20.).reshape((2, 10))
self.y102 = np.arange(20.).reshape((10, 2))
-
+
self.fill_value = -100.0

def test_validation(self):
@@ -125,13 +125,30 @@
np.array([2.4, 5.6, 6.0]),
)

+    def test_cubic(self):
+        """ Check the actual implementation of spline interpolation.
+        """

-    def test_bounds(self):
+        interp10 = interp1d(self.x10, self.y10, kind='cubic')
+        assert_array_almost_equal(
+            interp10(self.x10),
+            self.y10,
+        )
+        assert_array_almost_equal(
+            interp10(1.2),
+            np.array([1.2]),
+        )
+        assert_array_almost_equal(
+            interp10([2.4, 5.6, 6.0]),
+            np.array([2.4, 5.6, 6.0]),
+        )
+
+    def _bounds_check(self, kind='linear'):
""" Test that our handling of out-of-bounds input is correct.
"""

extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
-            bounds_error=False)
+            bounds_error=False, kind=kind)
assert_array_equal(
extrap10(11.2),
np.array([self.fill_value]),
@@ -145,25 +162,28 @@
np.array([True, False, False, False, True]),
)

-        raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True)
+        raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True,
+                                       kind=kind)
self.assertRaises(ValueError, raises_bounds_error, -1.0)
self.assertRaises(ValueError, raises_bounds_error, 11.0)
raises_bounds_error([0.0, 5.0, 9.0])

+    def test_bounds(self):
+        for kind in ('linear', 'cubic'):
+            self._bounds_check(kind=kind)

-    def test_nd(self):
+    def _nd_check(self, kind='linear'):
""" Check the behavior when the inputs and outputs are multidimensional.
"""
-
# Multidimensional input.
-        interp10 = interp1d(self.x10, self.y10)
+        interp10 = interp1d(self.x10, self.y10, kind=kind)
assert_array_almost_equal(
interp10(np.array([[3.4, 5.6], [2.4, 7.8]])),
np.array([[3.4, 5.6], [2.4, 7.8]]),
)
-
+
# Multidimensional outputs.
-        interp210 = interp1d(self.x10, self.y210)
+        interp210 = interp1d(self.x10, self.y210, kind=kind)
assert_array_almost_equal(
interp210(1.5),
np.array([[1.5], [11.5]]),
@@ -174,7 +194,7 @@
[11.5, 12.4]]),
)

-        interp102 = interp1d(self.x10, self.y102, axis=0)
+        interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
assert_array_almost_equal(
interp102(1.5),
np.array([[3.0, 4.0]]),
@@ -197,7 +217,24 @@
np.array([[[6.8, 7.8], [11.2, 12.2]],
[[4.8, 5.8], [15.6, 16.6]]]),
)
+
+        # Check large ndim output
+        a = [4, 5, 6, 7]
+        y = np.arange(np.prod(a)).reshape(*a)
+        for n, s in enumerate(a):
+            x = np.arange(s)
+            z = interp1d(x, y, axis=n, kind=kind)
+            assert_array_almost_equal(z(x), y)
+
+            x2 = np.arange(2*3*1).reshape((2,3,1)) / 12.
+            b = list(a)
+            b[n:n+1] = [2,3,1]
+            assert_array_almost_equal(z(x2).shape, b)

+    def test_nd(self):
+        for kind in ('linear', 'cubic'):
+            self._nd_check(kind=kind)
+
class TestLagrange(TestCase):

def test_lagrange(self):

```