[Scipy-svn] r6682 - in trunk/scipy/interpolate: . tests

scipy-svn@scip... scipy-svn@scip...
Sun Sep 5 05:40:19 CDT 2010


Author: ptvirtan
Date: 2010-09-05 05:40:19 -0500 (Sun, 05 Sep 2010)
New Revision: 6682

Modified:
   trunk/scipy/interpolate/griddatand.py
   trunk/scipy/interpolate/tests/test_griddatand.py
Log:
ENH: interpolate: add a 'fill_value' keyword to griddata

Modified: trunk/scipy/interpolate/griddatand.py
===================================================================
--- trunk/scipy/interpolate/griddatand.py	2010-09-05 00:04:24 UTC (rev 6681)
+++ trunk/scipy/interpolate/griddatand.py	2010-09-05 10:40:19 UTC (rev 6682)
@@ -64,7 +64,7 @@
 # Convenience interface function
 #------------------------------------------------------------------------------
 
-def griddata(points, values, xi, method='linear'):
+def griddata(points, values, xi, method='linear', fill_value=np.nan):
     """
     griddata(points, values, xi, method='linear')
 
@@ -101,7 +101,13 @@
           approximately curvature-minimizing polynomial surface. See
           `CloughTocher2DInterpolator` for more details.
 
+    fill_value : float, optional
+        Value used to fill in for requested points outside of the
+        convex hull of the input points.  If not provided, then the
+        default is ``nan``. This option has no effect for the
+        'nearest' method.
 
+
     Examples
     --------
 
@@ -156,16 +162,17 @@
     ndim = points.shape[-1]
 
     if ndim == 1 and method in ('nearest', 'linear', 'cubic'):
-        ip = interp1d(points, values, kind=method, axis=0, bounds_error=False)
+        ip = interp1d(points, values, kind=method, axis=0, bounds_error=False,
+                      fill_value=fill_value)
         return ip(xi)
     elif method == 'nearest':
         ip = NearestNDInterpolator(points, values)
         return ip(xi)
     elif method == 'linear':
-        ip = LinearNDInterpolator(points, values)
+        ip = LinearNDInterpolator(points, values, fill_value=fill_value)
         return ip(xi)
     elif method == 'cubic' and ndim == 2:
-        ip = CloughTocher2DInterpolator(points, values)
+        ip = CloughTocher2DInterpolator(points, values, fill_value=fill_value)
         return ip(xi)
     else:
         raise ValueError("Unknown interpolation method %r for "

Modified: trunk/scipy/interpolate/tests/test_griddatand.py
===================================================================
--- trunk/scipy/interpolate/tests/test_griddatand.py	2010-09-05 00:04:24 UTC (rev 6681)
+++ trunk/scipy/interpolate/tests/test_griddatand.py	2010-09-05 10:40:19 UTC (rev 6682)
@@ -5,7 +5,16 @@
 
 
 class TestGriddata(object):
+    def test_fill_value(self):
+        x = [(0,0), (0,1), (1,0)]
+        y = [1, 2, 3]
 
+        yi = griddata(x, y, [(1,1), (1,2), (0,0)], fill_value=-1)
+        assert_array_equal(yi, [-1, -1, 1])
+
+        yi = griddata(x, y, [(1,1), (1,2), (0,0)])
+        assert_array_equal(yi, [np.nan, np.nan, 1])
+
     def test_alternative_call(self):
         x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
                      dtype=np.double)
@@ -14,7 +23,7 @@
 
         for method in ('nearest', 'linear', 'cubic'):
             yi = griddata((x[:,0], x[:,1]), y, (x[:,0], x[:,1]), method=method)
-            assert_almost_equal(y, yi, err_msg=method)
+            assert_allclose(y, yi, atol=1e-14, err_msg=method)
 
     def test_multivalue_2d(self):
         x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
@@ -24,7 +33,7 @@
 
         for method in ('nearest', 'linear', 'cubic'):
             yi = griddata(x, y, x, method=method)
-            assert_almost_equal(y, yi, err_msg=method)
+            assert_allclose(y, yi, atol=1e-14, err_msg=method)
 
     def test_multipoint_2d(self):
         x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
@@ -37,7 +46,8 @@
             yi = griddata(x, y, xi, method=method)
 
             assert_equal(yi.shape, (5, 3), err_msg=method)
-            assert_almost_equal(yi, np.tile(y[:,None], (1, 3)), err_msg=method)
+            assert_allclose(yi, np.tile(y[:,None], (1, 3)),
+                            atol=1e-14, err_msg=method)
 
     def test_complex_2d(self):
         x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
@@ -51,7 +61,8 @@
             yi = griddata(x, y, xi, method=method)
 
             assert_equal(yi.shape, (5, 3), err_msg=method)
-            assert_almost_equal(yi, np.tile(y[:,None], (1, 3)), err_msg=method)
+            assert_allclose(yi, np.tile(y[:,None], (1, 3)),
+                            atol=1e-14, err_msg=method)
 
 if __name__ == "__main__":
     run_module_suite()



More information about the Scipy-svn mailing list