[Scipy-svn] r5055 - in trunk/scipy/interpolate: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Nov 10 16:50:42 CST 2008


Author: ptvirtan
Date: 2008-11-10 16:50:30 -0600 (Mon, 10 Nov 2008)
New Revision: 5055

Modified:
   trunk/scipy/interpolate/interpolate.py
   trunk/scipy/interpolate/tests/test_interpolate.py
Log:
Make interp1d treat scalar inputs differently from 1-d arrays: a scalar x_new now yields a result without the extra length-1 axis (fixes #660)

Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py	2008-11-10 19:01:13 UTC (rev 5054)
+++ trunk/scipy/interpolate/interpolate.py	2008-11-10 22:50:30 UTC (rev 5055)
@@ -286,7 +286,7 @@
         return result.reshape(x_new.shape+result.shape[1:])
 
     def __call__(self, x_new):
-        """ Find linearly interpolated y_new = f(x_new).
+        """Find interpolated y_new = f(x_new).
 
         Parameters
         ----------
@@ -296,13 +296,14 @@
         Returns
         -------
         y_new : number or array
-            Linearly interpolated value(s) corresponding to x_new.
+            Interpolated value(s) corresponding to x_new.
+
         """
 
         # 1. Handle values in x_new that are outside of x.  Throw error,
         #    or return a list of mask array indicating the outofbounds values.
         #    The behavior is set by the bounds_error variable.
-        x_new = atleast_1d(x_new)
+        x_new = asarray(x_new)
         out_of_bounds = self._check_bounds(x_new)
 
         y_new = self._call(x_new)
@@ -318,7 +319,15 @@
         # and
         # 7. Rotate the values back to their proper place.
 
-        if self._kind == 'linear':
+        if nx == 0:
+            # special case: x is a scalar
+            if out_of_bounds:
+                if ny == 0:
+                    return self.fill_value
+                else:
+                    y_new[...] = self.fill_value
+            return y_new
+        elif self._kind == 'linear':
             y_new[..., out_of_bounds] = self.fill_value
             axes = range(ny - nx)
             axes[self.axis:self.axis] = range(ny - nx, ny)
@@ -330,7 +339,7 @@
             return y_new.transpose(axes)
 
     def _check_bounds(self, x_new):
-        """ Check the inputs for being in the bounds of the interpolated data.
+        """Check the inputs for being in the bounds of the interpolated data.
 
         Parameters
         ----------

Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py	2008-11-10 19:01:13 UTC (rev 5054)
+++ trunk/scipy/interpolate/tests/test_interpolate.py	2008-11-10 22:50:30 UTC (rev 5055)
@@ -158,13 +158,17 @@
             bounds_error=False, kind=kind)
         assert_array_equal(
             extrap10(11.2),
-            np.array([self.fill_value]),
+            np.array(self.fill_value),
         )
         assert_array_equal(
             extrap10(-3.4),
-            np.array([self.fill_value]),
+            np.array(self.fill_value),
         )
         assert_array_equal(
+            extrap10([[[11.2], [-3.4], [12.6], [19.3]]]),
+            np.array(self.fill_value),
+        )
+        assert_array_equal(
             extrap10._check_bounds(np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
             np.array([True, False, False, False, True]),
         )
@@ -193,7 +197,7 @@
         interp210 = interp1d(self.x10, self.y210, kind=kind)
         assert_array_almost_equal(
             interp210(1.5),
-            np.array([[1.5], [11.5]]),
+            np.array([1.5, 11.5]),
         )
         assert_array_almost_equal(
             interp210(np.array([1.5, 2.4])),
@@ -204,7 +208,7 @@
         interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
         assert_array_almost_equal(
             interp102(1.5),
-            np.array([[3.0, 4.0]]),
+            np.array([3.0, 4.0]),
         )
         assert_array_almost_equal(
             interp102(np.array([1.5, 2.4])),



More information about the Scipy-svn mailing list