[Scipy-svn] r4581 - in branches/Interpolate1D: . docs

scipy-svn@scip... scipy-svn@scip...
Wed Jul 30 15:11:38 CDT 2008


Author: fcady
Date: 2008-07-30 15:11:37 -0500 (Wed, 30 Jul 2008)
New Revision: 4581

Modified:
   branches/Interpolate1D/docs/tutorial.rst
   branches/Interpolate1D/interpolate1d.py
Log:
documentation extensively bolstered.  Hoping for feedback from the scipy development community

Modified: branches/Interpolate1D/docs/tutorial.rst
===================================================================
--- branches/Interpolate1D/docs/tutorial.rst	2008-07-30 14:43:00 UTC (rev 4580)
+++ branches/Interpolate1D/docs/tutorial.rst	2008-07-30 20:11:37 UTC (rev 4581)
@@ -1,15 +1,26 @@
+==================
 Overview
---------
+==================
 
-The interpolate package provides tools for interpolating and extrapolating new data points from a set known set of data points.  Intepolate provides both a functional interface that is flexible and easy to use as well as an object oriented interface that can be more efficient and flexible for some cases.  It is able to interpolate and extrapolate in 1D, 2D, and even N dimensions.[fixme: 1D only right now]  
+The interpolate package provides tools for interpolating and extrapolating new data points from a known set of data points.
+Interpolate provides both a functional interface that is flexible and easy to use and an object-oriented interface that
+can be more efficient for some cases.  It is able to interpolate and extrapolate in 1D, 2D, and even N
+dimensions. *[fixme: 1D only right now]*
 
-For 1D interpolation, it handles linear and spline(cubic, quadratic, and quintic) for both uniformly and non-uniformly spaced data points "out of the box."  Users can control the behavior of values that fall outside of the range of interpolation either by When new values fall outside of the range of interpolation data, the tools can be   
+For 1D interpolation, it handles linear and spline (quadratic, cubic, quartic, and quintic) interpolation for both uniformly
+and non-uniformly spaced data points "out of the box."  By default, new values that fall outside the range of the
+interpolation data are set to NaN; users can change this behavior with the extrap_low and extrap_high keywords.
 
 For 2D interpolation, 
 
+================================================
 1D Interpolation with the Functional Interface
-----------------------------------------------
+================================================
 
+-------------
+Basic Usage
+-------------
+
 The following example uses the 'interp1d' function to linearly interpolate a sine curve from a sparse set of values::
 
 	# start up ipython for our examples.
@@ -18,16 +29,180 @@
 	In [1]: from interpolate import interp1d
 	
 	# Create our "known" set of 5 points with the x values in one array and the y values in another.
-	In [2]: x = linspace(0,2*pi,5)
+	In [2]: x = linspace(0, 2*pi, 5)
 	In [3]: y = sin(x)
-	
-	# Now interpolate from these x,y values to create a more dense set of new_x, new_y values.
-	In [4]: new_x = linspace(0,2*pi, 21)
-	In [5]: new_y = interp1d(x,y, new_x)
-	
+
+	# If we only want a value at a single point, we can pass in a scalar and interp1d
+	# will return a scalar.
+	In [4]: interp1d(x, y, 1.2)
+	Out[4]: 0.76394372684109768
+
+	# 0-dimensional arrays are also treated as scalars.
+	In [5]: interp1d(x, y, array(1.2))
+	Out[5]: 0.76394372684109768
+
+	# To interpolate from these x, y values at multiple points (for example, to get a denser
+	# set of new_x, new_y values), pass a numpy array to interp1d; the result is also a numpy array.
+	In [6]: new_x = linspace(0, 2*pi, 21)
+	In [7]: new_y = interp1d(x, y, new_x)
+
 	# Plot the results using matplotlib. [note examples assume you are running in ipython -pylab]
-	In [6]: plot(x,y,'ro', new_x, new_y, 'b-')
-	
+	In [8]: plot(x, y, 'ro', new_x, new_y, 'b-')
+
 .. image:: interp1d_linear_simple.png
 
+::
+    
+    # Alternatively, x, y and new_x can also be lists (they are internally converted into arrays
+    # before processing)
+    In []: interp1d([1.0, 2.0], [1.0, 2.0], [1.3])
+    Out []: array([ 1.3])
+	
+
+
+What happens if we pass in a new_x with values outside the range of x?  By default, new_y will be
+NaN at all such points: ::
+
+    # If we attempt to extrapolate values outside the interpolation range, interp1d defaults
+    # to returning NaN
+    In [7]: interp1d(x, y, array([-2, -1, 1, 2]))
+    Out[7]: array([        NaN,         NaN,  0.63661977,  0.72676046])
+
+
+For interpolation other than linear, there is a range of options, which can be selected
+with the keyword argument interp (usually a string).  For example::
+
+    # If we want quadratic (2nd order) spline interpolation, we can use the string 'quadratic'
+    In [7]: new_y_quadratic = interp1d(x, y, new_x, interp = 'quadratic')
+    In [8]: plot(x, y, 'r', new_x, new_y_quadratic, 'g')
+    
+.. image:: interp1d_linear_and_quadratic.png
+
+
+A range of interpolation methods can be selected by string.  The accepted strings include:
+
+#. 'linear' : linear interpolation, the default
+#. 'block' : "round new_x down" to the nearest value of x at which y is known
+#. 'spline' : spline interpolation of default order (currently 3)
+#. 'quadratic' : 2nd order spline interpolation
+#. 'cubic' : 3rd order spline interpolation
+#. 'quartic' : 4th order spline interpolation
+#. 'quintic' : 5th order spline interpolation
+
+Capitalized forms (e.g. 'Cubic') and the abbreviations 'quad', 'quar', and 'quin' are also accepted.
+
+The same flexibility is afforded for extrapolation by the keywords extrap_low and extrap_high: ::
+
+    In []: z = array([ 1.0, 2.0 ])
+    In []: interp1d(z, z, array([-5.0, 5.0]), extrap_low = 'linear', extrap_high = 'linear')
+    Out []: array([-5.0, 5.0])
+
+If an unrecognized string is passed, a TypeError is raised.
+
+Finally, interp, extrap_low, and extrap_high can each be set to a constant, which is then returned for every
+point in the corresponding region (the constant must be neither callable nor a string): ::
+
+    In []: interp1d(x, y, array([ -5.0, 1.1, 100 ]), interp = 8.2, extrap_low = 7.2, extrap_high = 9.2)
+    Out []: array([ 7.2, 8.2, 9.2 ])
+    
+It is also possible, though slightly trickier, to define your own interpolation methods and pass them
+to interp, extrap_low, and extrap_high.  For more information, see "User-defined Interpolation Methods"
+below.
+
+
+
+-----------------------------
+Removal of Bad Datapoints
+-----------------------------
+
+Many datasets have missing or corrupt data that should be ignored when interpolating,
+and to this end, interp1d has the keyword argument bad_data.
+
+bad_data defaults to None.  If it is a list, all "bad" points (x[i], y[i]) are removed
+before any interpolation is performed.  A point is "bad" if:
+
+1) either x[i] or y[i] is in bad_data, or
+2) either x[i] or y[i] is NaN.
+
+Note that bad_data must be either None or a list of numbers.  Including NaN or None in the list,
+for example, is not supported.  NaNs are removed anyway whenever bad_data is a list, and None must
+not appear in the data. ::
+
+    # Compare the results with and without removing the bad data
+    In []: x = arange(10.); y = arange(10.)
+    In []: x[1] = NaN # bad data
+    In []: y[2] = 55   # bad data
+    In []: new_x = arange(0, 10., .1)
+    In []: new_y_bad = interp1d(x, y, new_x)
+    In []: new_y_no_bad = interp1d(x, y, new_x, bad_data=[55])
+    In []: plot(new_x, new_y_bad, 'r', new_x, new_y_no_bad, 'g')
+    
+.. image:: with_and_without_bad_data.png
+
+
+
+--------------------------------------
+User-defined Interpolation Methods
+--------------------------------------
+
+If you want more direct control than the string interface affords, you can supply your own
+interpolation method.  Note, however, that this is not for the faint-hearted: the method must
+have exactly the format described below, and failure to follow it can cause a range of errors.
+
+interp, extrap_low, and extrap_high can each be set to a function, a callable class, or an
+instance of a callable class.
+
+If a function is passed, it will be called when interpolating.
+It is assumed to have the form ::
+
+        newy = interp(x, y, newx, **kw)
+
+where x, y, newx, and newy are all numpy arrays.
+
+If a callable class is passed, it is assumed to have the format ::
+
+        instance = Class(x, y, **kw)
+
+and the instance is then called as ::
+
+        new_y = instance(new_x)
+
+(A callable class that has an "init_xy" or "set_xy" method is instead created as Class(**kw),
+and x and y are passed to that method.)
+
+If an instance of a callable class with method "init_xy" or "set_xy" is
+passed, that method will be used to set x and y as follows: ::
+
+        instance.set_xy(x, y, **kw)
+
+and the object will be called during interpolation: ::
+
+        new_y = instance(new_x)
+
+If neither "init_xy" nor "set_xy" is present, the object will be called as ::
+
+        new_y = argument(new_x)
+
+A primitive type which is not a string signifies a function
+which is identically that value (e.g. val and
+lambda x, y, newx : val are equivalent).
+
+The following example passes a function and a callable class, along with extra keywords: ::
+
+    # Define a function and a callable class that both have acceptable forms,
+    # then use them for the interpolation and extrapolation ranges.
+    In []: def dummy(x, y, newx, default = 5.1):
+                # return one value per point in newx
+                return default * np.ones(len(newx))
+    In []: class Phony:
+                def __init__(self, val = 4.0):
+                    self.val = val
+                def init_xy(self, x, y):
+                    pass
+                def __call__(self, newx):
+                    return self.val * np.ones(len(newx))
+    In []: x = arange(5.0)
+    In []: y = arange(5.0)
+    In []: new_x = np.array([ -1, .4, 7 ])
+    In []: new_y = interp1d(x, y, new_x, interp = Phony,
+                                        interpkw = {'val':1.0},
+                                        extrap_low = dummy,
+                                        lowkw = {'default':7.1},
+                                        extrap_high = dummy
+                                        )
+    In []: new_y
+    Out []: array([ 7.1, 1.0, 5.1 ])
+
  
\ No newline at end of file
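
The tutorial's default behavior (linear interpolation inside the data range, NaN outside it, a scalar result for a scalar input) can be reproduced for comparison with NumPy's numpy.interp, whose left and right arguments set the values returned below and above the data. This is a minimal sketch that swaps in numpy.interp; it is not the branch's interp1d, and the helper name linear_with_nan is hypothetical.

```python
import numpy as np

def linear_with_nan(x, y, new_x):
    """Linearly interpolate y(x) at new_x, returning NaN outside [x[0], x[-1]].

    Hypothetical helper built on numpy.interp (not the branch's interp1d).
    numpy.interp requires x to be increasing; left/right are the values
    returned below x[0] and above x[-1].
    """
    return np.interp(new_x, x, y, left=np.nan, right=np.nan)

# The tutorial's "known" points: 5 samples of sin on [0, 2*pi].
x = np.linspace(0, 2 * np.pi, 5)
y = np.sin(x)

print(linear_with_nan(x, y, 1.2))                       # a scalar
print(linear_with_nan(x, y, np.array([-2, -1, 1, 2])))  # an array; NaN for -2 and -1
```

The first call gives roughly 0.7639, the value shown in the tutorial's scalar example.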

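The removal rule in "Removal of Bad Datapoints" (and the vectorized _remove_bad_data in interpolate1d.py) amounts to building one boolean mask. Below is a minimal NumPy sketch, assuming x and y have equal length; remove_bad_data is a hypothetical stand-alone name, and converting the inputs with np.asarray is an addition of this sketch so that lists work too.

```python
import numpy as np

def remove_bad_data(x, y, bad_data):
    """Return copies of x and y without the points where x[i] or y[i] is NaN
    or equals a number in bad_data (hypothetical helper)."""
    # Convert lists to float arrays so isnan and == work element-wise.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    bad = np.isnan(x) | np.isnan(y)
    for value in bad_data:
        bad |= (x == value) | (y == value)
    # Boolean indexing with the inverted mask keeps only the good points.
    return x[~bad], y[~bad]

# The tutorial's example data: one NaN in x and one flagged value in y.
x = np.arange(10.)
y = np.arange(10.)
x[1] = np.nan
y[2] = 55
good_x, good_y = remove_bad_data(x, y, bad_data=[55])
print(good_x)   # the points at indices 1 and 2 are gone
```

Looping over bad_data rather than over the data keeps the cost at one vectorized comparison per listed value.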
Modified: branches/Interpolate1D/interpolate1d.py
===================================================================
--- branches/Interpolate1D/interpolate1d.py	2008-07-30 14:43:00 UTC (rev 4580)
+++ branches/Interpolate1D/interpolate1d.py	2008-07-30 20:11:37 UTC (rev 4581)
@@ -5,9 +5,30 @@
 from fitpack_wrapper import Spline
 import numpy as np
 from numpy import array, arange, empty, float64, NaN
-    
-def interp1d(x, y, new_x, interp = 'linear', extrap_low = NaN, extrap_high = NaN,
+
+# dictionary of tuples.  First element is a callable (class, instance of a class, or function);
+# second element is a dictionary of additional keywords, if any
+dict_of_interp_types = \
+                { 'linear' : (linear, {}), 
+                    'logarithmic' : (logarithmic, {}), 
+                    'block' : (block, {}),
+                    'block_average_above' : (block_average_above, {}),
+                    'Spline' : (Spline, {}), 'spline' : (Spline, {}),
+                    'Quadratic' : (Spline, {'k':2}), 'quadratic' : (Spline, {'k':2}),
+                    'Quad' : (Spline, {'k':2}), 'quad' : (Spline, {'k':2}),
+                    'Cubic' : (Spline, {'k':3}), 'cubic' : (Spline, {'k':3}),
+                    'Quartic' : (Spline, {'k':4}), 'quartic' : (Spline, {'k':4}),
+                    'Quar' : (Spline, {'k':4}), 'quar' : (Spline, {'k':4}),
+                    'Quintic' : (Spline, {'k':5}), 'quintic' : (Spline, {'k':5}),
+                    'Quin' : (Spline, {'k':5}), 'quin' : (Spline, {'k':5})
+                }
+
+def interp1d(x, y, new_x, 
+                    interp = 'linear', extrap_low = NaN, extrap_high = NaN,
+                    interpkw = {}, lowkw = {}, highkw = {},
                     bad_data = None):
+    # FIXME : allow y to be multi-dimensional
+    # FIXME : update the doc string to match that of Interpolate1d
     """ A function for interpolation of 1D data.
         
         Parameters
@@ -30,7 +51,7 @@
         Optional Arguments
         -------------------
         
-        kind -- Usu. function or string.  But can be any type.
+        interp -- Usually a function or string.  But can be any type.
             Specifies the type of interpolation to use for values within
             the range of x.  If a string is passed, it will look for an object
             or function with that name and call it when evaluating.  If 
@@ -38,17 +59,17 @@
             If nothing else, assumes the argument is intended as a value
             to be returned for all arguments.  Defaults to linear interpolation.
             
-        low (high) -- same as for kind
-            Same options as for 'kind'.  Defaults to returning numpy.NaN ('not 
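+        extrap_low (extrap_high) -- same as for interp
+            Same options as for 'interp'.  Defaults to returning numpy.NaN ('not
             a number') for all values outside the range of x.
         
-        remove_bad_data -- bool
-            indicates whether to remove bad data.
+        interpkw -- dictionary
+            If interp is set to a function, class or callable object, this
+            contains additional keywords.
+
+        lowkw (highkw) -- dictionary
+            Like interpkw, but for extrap_low (extrap_high).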
             
         bad_data -- list
             List of values (in x or y) which indicate unacceptable data. All points
             that have x or y value in bad_data will be removed before
-            any interpolatin is performed if remove_bad_data is true.
+            any interpolation is performed if bad_data is not None.
             
             numpy.NaN is always considered bad data.
             
@@ -78,6 +99,9 @@
                                 interp = interp,
                                 extrap_low = extrap_low,
                                 extrap_high = extrap_high,
+                                interpkw = interpkw,
+                                lowkw = lowkw,
+                                highkw = highkw,
                                 bad_data = bad_data
                                 )(new_x)
 
@@ -87,70 +111,51 @@
         Parameters
         -----------
             
-        x -- list or 1D NumPy array
-            x includes the x-values for the data set to
-            interpolate from.  It must be sorted in
-            ascending order.
+            x -- list or 1D NumPy array
+                x includes the x-values for the data set to
+                interpolate from.  It must be sorted in
+                ascending order.
+                    
+            y -- list or 1D NumPy array
+                y includes the y-values for the data set to
+                interpolate from.  Note that 2-dimensional
+                y is not supported.
                 
-        y -- list or 1D NumPy array
-            y includes the y-values for the data set  to
-            interpolate from.  Note that 2-dimensional
-            y is not supported.
-                
         Optional Arguments
         -------------------
         
-        kind -- Usu. string or function.  But can be any type.
-            Specifies the type of interpolation to use for values within
-            the range of x.
-            
-            If a string is passed, it will look for an object
-            or function with that name and call it when evaluating.
-            This is the primary mode of operation.  See below for list
-            of acceptable strings.
-            
-            By default, linear interpolation is used.
-            
-            Other options are also available:
-            
-                If a callable class is passed, it is assumed to have format
-                    instance = Class(x, y).
-                It is instantiated and used for interpolation when the instance
-                of Interpolate1d is called.
+            interp -- Usually a string.  But can be any type.
+                Specifies the type of interpolation to use for values within
+                the range of x.
                 
-                If a callable object with method "init_xy" or "set_xy" is
-                passed, that method will be used to set x and y, and the
-                object will be called during interpolation.
+                By default, linear interpolation is used.
                 
-                If a function is passed, it will be called when interpolating.
-                It is assumed to have the form 
-                    newy = kind(x, y, newx), 
-                where x, y, newx, and newy are all numpy arrays.
+                See below for details on other options.
                 
-                A primitive type which is not a string signifies a function
-                which is identically that value (e.g. val and 
-                lambda x, y, newx : val are equivalent).
+            extrap_low -- same as for interp
+                How to extrapolate values for inputs below the range of x.
+                Same options as for 'interp'.  Defaults to returning numpy.NaN ('not
+                a number') for all values below the range of x.
+
+            extrap_high -- same as for interp
+                How to extrapolate values for inputs above the range of x.
+                Same options as for 'interp'.  Defaults to returning numpy.NaN ('not
+                a number') for all values above the range of x.
+                
+            bad_data -- list
+                List of numerical values (in x or y) which indicate unacceptable data. 
+                
+                If bad_data is not None (the default is None), all points with an x or y
+                coordinate in bad_data, or with an x or y coordinate that is NaN, will be removed.
+                
+            interpkw -- dictionary
+                If interp is set to a function, class or callable object, this contains
+                additional keywords.
+                
+            lowkw (highkw) -- dictionary
+                Like interpkw, but for extrap_low (extrap_high).
+                
             
-        low  -- same as for kind
-            How to extrapolate values for inputs below the range of x.
-            Same options as for 'kind'.  Defaults to returning numpy.NaN ('not 
-            a number') for all values below the range of x.
-            
-        high  -- same as for kind
-            How to extrapolate values for inputs above the range of x.
-            Same options as for 'kind'.  Defaults to returning numpy.NaN ('not 
-            a number') for all values above the range of x.
-        
-        remove_bad_data -- bool
-            indicates whether to remove bad data points from x and y.
-            
-        bad_data -- list
-            List of values (in x or y) which indicate unacceptable data. All points
-            that have x or y value in missing_data will be removed before
-            any interpolatin is performed if remove_bad_data is true.
-            
-            numpy.NaN is always considered bad data.
-            
         Some Acceptable Input Strings
         ------------------------
         
@@ -164,31 +169,72 @@
             "cubic" -- spline interpolation order 3
             "quartic" -- spline interpolation order 4
             "quintic" -- spline interpolation order 5
+            
+        Other options for interp, extrap_low, and extrap_high
+        ---------------------------------------------------
         
-        Examples
+            If you choose to use a non-string argument, you must
+            be careful to use correct formatting.
+            
+            If a function is passed, it will be called when interpolating.
+            It is assumed to have the form 
+                newy = interp(x, y, newx, **kw), 
+            where x, y, newx, and newy are all numpy arrays.
+            
+            If a callable class is passed, it is assumed to have format
+                instance = Class(x, y, **kw).
+            which can then be called by
+                new_y = instance(new_x)
+            
+            If a callable object with method "init_xy" or "set_xy" is
+            passed, that method will be used to set x and y as follows
+                instance.set_xy(x, y, **kw)
+            and the object will be called during interpolation.
+                new_y = instance(new_x)
+            If neither "init_xy" nor "set_xy" is present, it will be called as
+                new_y = argument(new_x)
+                
+            A primitive type which is not a string signifies a function
+            which is identically that value (e.g. val and 
+            lambda x, y, newx : val are equivalent).
+            
+        Example
         ---------
         
             >>> import numpy
             >>> from interpolate1d import Interpolate1d
             >>> x = range(5)        # note list is permitted
             >>> y = numpy.arange(5.)
-            >>> new_x = [.2, 2.3, 5.6]
+            >>> new_x = [.2, 2.3, 5.6, 7.0]
             >>> interp_func = Interpolate1d(x, y)
             >>> interp_func(new_x)
             array([ 0.2,  2.3,  NaN,  NaN])
+            
     """
     # FIXME: more informative descriptions of sample arguments
     # FIXME: examples in doc string
     # FIXME : Allow copying or not of arrays.  non-copy + remove_bad_data should flash 
     #           a warning (esp if we interpolate missing values), but work anyway.
     
-    def __init__(self, x, y, interp = 'linear', extrap_low = NaN, extrap_high = NaN,
+    def __init__(self, x, y, 
+                        interp = 'linear', 
+                        extrap_low = NaN, 
+                        extrap_high = NaN,
+                        interpkw = {},
+                        lowkw = {},
+                        highkw = {},
                         bad_data = None):
         # FIXME: don't allow copying multiple times.
         # FIXME : allow no copying, in case user has huge dataset
         
         # remove bad data, if there is any
         if bad_data is not None:
+            try:
+                sum(bad_data)   # fails unless bad_data holds only numbers
+            except TypeError:
+                raise TypeError("bad_data must be either None "
+                                "or a list of numerical types")
+            
             x, y = self._remove_bad_data(x, y, bad_data)
         
         # check acceptable size and dimensions
@@ -204,9 +250,9 @@
         self._init_xy(x, y)
         
         # store interpolation functions for each range
-        self.interp = self._init_interp_method(interp)
-        self.extrap_low = self._init_interp_method(extrap_low)
-        self.extrap_high = self._init_interp_method(extrap_high)
+        self.interp = self._init_interp_method(interp, interpkw)
+        self.extrap_low = self._init_interp_method(extrap_low, lowkw)
+        self.extrap_high = self._init_interp_method(extrap_high, highkw)
 
     def _init_xy(self, x, y):
         
@@ -216,21 +262,23 @@
         self._x = atleast_1d_and_contiguous(x, self._xdtype).copy()
         self._y = atleast_1d_and_contiguous(y, self._ydtype).copy()
 
-    def _remove_bad_data(self, x, y, bad_data = [None, NaN]):
+    def _remove_bad_data(self, x, y, bad_data = []):
         """ removes data points whose x or y coordinate is
             either in bad_data or is a NaN.
         """
         # FIXME : In the future, it may be good to just replace the bad points with good guesses.
         #       Especially in generalizing the higher dimensions
         # FIXME : This step is very inefficient because it iterates over the array
-        mask = np.array([  (xi not in bad_data) and (not np.isnan(xi)) and \
-                                    (y[i] not in bad_data) and (not np.isnan(y[i])) \
-                                for i, xi in enumerate(x) ])
-        x = x[mask]
-        y = y[mask]
+        x, y = np.asarray(x), np.asarray(y)   # x and y may be passed in as lists
+        bad_data_mask = np.isnan(x) | np.isnan(y)
+        for bad_num in bad_data:
+            bad_data_mask = bad_data_mask | (x == bad_num) | (y == bad_num)
+
+        x = x[~bad_data_mask]
+        y = y[~bad_data_mask]
         return x, y
         
-    def _init_interp_method(self, interp_arg):
+    def _init_interp_method(self, interp_arg, kw):
         """
             returns the interpolating function specified by interp_arg.
         """
@@ -241,60 +289,46 @@
         from inspect import isclass, isfunction
         
         # primary usage : user passes a string indicating a known function
-        if interp_arg in ['linear', 'logarithmic', 'block', 'block_average_above']:
-            # string used to indicate interpolation method,  Select appropriate function
-            func = {'linear':linear, 'logarithmic':logarithmic, 'block':block, \
-                        'block_average_above':block_average_above}[interp_arg]
-            result = lambda new_x : func(self._x, self._y, new_x)
-        elif interp_arg in ['Spline', 'spline']:
-            # use the Spline class from fitpack_wrapper
-            # k = 3 unless otherwise specified
-            result = Spline(self._x, self._y)
-        elif interp_arg in ['Quadratic', 'quadratic', 'Quad', 'quad', \
-                                'Cubic', 'cubic', \
-                                'Quartic', 'quartic', 'Quar', 'quar',\
-                                'Quintic', 'quintic', 'Quin', 'quin']:
-            # specify specific kinds of splines
-            if interp_arg in ['Quadratic', 'quadratic', 'Quad', 'quad']:
-                result = Spline(self._x, self._y, k=2)
-            elif interp_arg in ['Cubic', 'cubic']:
-                result = Spline(self._x, self._y, k=3)
-            elif interp_arg in ['Quartic', 'quartic', 'Quar', 'quar']:
-                result = Spline(self._x, self._y, k=4)
-            elif interp_arg in ['Quintic', 'quintic', 'Quin', 'quin']:
-                result = Spline(self._x, self._y, k=5)
-        elif isinstance(interp_arg, basestring):
-            raise TypeError, "input string %s not valid" % interp_arg
+        if isinstance(interp_arg, basestring):
+            interpolator, kw = dict_of_interp_types.get(interp_arg, (None, {}))
+            
+            if interpolator is None: 
+                raise TypeError, "input string %s not valid" % interp_arg
+        else:
+            interpolator = interp_arg
         
-        # secondary usage : user passes a callable class
-        elif isclass(interp_arg) and hasattr(interp_arg, '__call__'):
-            if hasattr(interp_arg, 'init_xy'):
-                result = interp_arg()
-                result.init_xy(self._x, self._y)
-            elif hasattr(interp_arg, 'set_xy'):
-                result = interp_arg()
-                result.set_xy(self._x, self._y)
-            else:
-                result = interp_arg(x, y)
+        # interpolator is a callable : function, class, or instance of class
+        if hasattr(interpolator, '__call__'):
+            # function
+            if isfunction(interpolator):
+                result = lambda newx : interpolator(self._x, self._y, newx, **kw)
                 
-        # user passes an instance of a callable class which has yet
-        # to have its x and y initialized.
-        elif hasattr(interp_arg, 'init_xy') and hasattr(interp_arg, '__call__'):
-            result = interp_arg
-            result.init_xy(self._x, self._y)
-        elif hasattr(interp_arg, 'set_xy') and hasattr(interp_arg, '__call__'):
-            result = interp_arg
-            result.set_xy(self._x, self._y)
+            # callable class 
+            elif isclass(interpolator):
+                if hasattr(interpolator, 'set_xy'):
+                    result = interpolator(**kw)
+                    result.set_xy(self._x, self._y)
+                elif hasattr(interpolator, 'init_xy'):
+                    result = interpolator(**kw)
+                    result.init_xy(self._x, self._y)
+                else:
+                    result = interpolator(self._x, self._y, **kw)
                 
-        # user passes a function to be called
-        # Assume function has form of f(x, y, newx)
-        elif isfunction(interp_arg):
-            result = lambda new_x : interp_arg(self._x, self._y, new_x)
-        
-        # default : user has passed a default value to always be returned
+            # instance of callable class
+            else:
+                if hasattr(interpolator, 'init_xy'):
+                    result = interpolator
+                    result.init_xy(self._x, self._y, **kw)
+                elif hasattr(interpolator, 'set_xy'):
+                    result = interpolator
+                    result.set_xy(self._x, self._y, **kw)
+                else:
+                    result = interpolator
+            
+        # non-callable : user has passed a default value to always be returned
         else:
             result = np.vectorize(lambda new_x : interp_arg)
-            
+        
         return result
 
     def __call__(self, newx):
@@ -308,8 +342,13 @@
         #   waste of time, but ok for the time being.
         
         # if input is scalar or 0-dimemsional array, output will be scalar
-        input_is_scalar = np.isscalar(newx) or (isinstance(newx, type(np.array([1.0]))) and np.shape(newx) == ())
+        input_is_scalar = np.isscalar(newx) or \
+                                    (
+                                        isinstance(  newx , np.ndarray  ) and 
+                                        np.shape(newx) == ()
+                                    )
         
+        # make sure newx is a 1D, contiguous array
         newx_array = atleast_1d_and_contiguous(newx)
         
         # masks indicate which elements fall into which interpolation region
@@ -317,6 +356,9 @@
         high_mask = newx_array>self._x[-1]
         interp_mask = (~low_mask) & (~high_mask)
                 
+        type(newx_array[low_mask])
+                
+                
         # use correct function for x values in each region
         if len(newx_array[low_mask]) == 0: new_low=np.array([])  # FIXME : remove need for if/else.
                                                                             # if/else is a hack, since vectorize is failing
@@ -333,6 +375,7 @@
                                                                                           # Would be nice to say result = zeros(dtype=?)
                                                                                           # and fill in
         
+        # convert to scalar if scalar was passed in
         if input_is_scalar:
             result = float(result_array)
         else:

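The new module-level dict_of_interp_types turns string lookup into a single dictionary access. The sketch below isolates that step with hypothetical stand-ins for linear and Spline. It uses dict.get, which, unlike dict.setdefault, does not insert an entry for every unrecognized name; lower-casing the requested name is an assumption of this sketch, offered as an alternative to listing 'Cubic' and 'cubic' separately.

```python
def linear(x, y, new_x):
    """Hypothetical stand-in for the package's linear interpolator."""
    return [0.0 for _ in new_x]

class Spline:
    """Hypothetical stand-in for fitpack_wrapper.Spline; it only records k."""
    def __init__(self, x=None, y=None, k=3):
        self.k = k

# name -> (callable, extra keyword arguments), as in dict_of_interp_types.
# Keys are lower case here; lookup() lower-cases the requested name.
INTERP_TYPES = {
    "linear": (linear, {}),
    "spline": (Spline, {}),
    "quadratic": (Spline, {"k": 2}), "quad": (Spline, {"k": 2}),
    "cubic": (Spline, {"k": 3}),
    "quartic": (Spline, {"k": 4}), "quar": (Spline, {"k": 4}),
    "quintic": (Spline, {"k": 5}), "quin": (Spline, {"k": 5}),
}

def lookup(name):
    """Return (callable, kw) for a known name; raise TypeError otherwise."""
    # dict.get leaves INTERP_TYPES unchanged for unknown names, whereas
    # dict.setdefault would store (None, {}) under each of them.
    interpolator, kw = INTERP_TYPES.get(name.lower(), (None, {}))
    if interpolator is None:
        raise TypeError("input string %s not valid" % name)
    return interpolator, kw

cls, kw = lookup("Quadratic")
print(cls.__name__, kw)   # Spline {'k': 2}
spline = cls([0, 1, 2], [0, 1, 4], **kw)
print(spline.k)           # 2
```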

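The rules for non-string arguments to interp, extrap_low, and extrap_high can be condensed into one dispatch function. This is a pure-Python sketch of those rules, using plain lists instead of NumPy arrays; make_interpolator is a hypothetical name, strings are assumed to have been resolved beforehand, and the order of the checks (function, then class, then other callables, then constants) follows _init_interp_method with its if/elif chain. It also reruns the tutorial's dummy/Phony example.

```python
import inspect

def make_interpolator(arg, x, y, kw=None):
    """Turn arg into a callable new_x -> list of new_y values (hypothetical helper)."""
    kw = kw or {}
    if inspect.isfunction(arg):
        # A function of the form f(x, y, new_x, **kw).
        return lambda new_x: arg(x, y, new_x, **kw)
    if inspect.isclass(arg):
        # A class with set_xy/init_xy is built with kw only, then given x and y;
        # any other class is built as Class(x, y, **kw).
        if hasattr(arg, "set_xy"):
            instance = arg(**kw)
            instance.set_xy(x, y)
        elif hasattr(arg, "init_xy"):
            instance = arg(**kw)
            instance.init_xy(x, y)
        else:
            instance = arg(x, y, **kw)
        return instance
    if callable(arg):
        # An existing callable object: hand it x and y if it knows how.
        if hasattr(arg, "init_xy"):
            arg.init_xy(x, y, **kw)
        elif hasattr(arg, "set_xy"):
            arg.set_xy(x, y, **kw)
        return arg
    # Anything else is a constant returned for every point.
    return lambda new_x: [arg] * len(new_x)

# The tutorial's example, with plain lists standing in for NumPy arrays.
def dummy(x, y, new_x, default=5.1):
    return [default] * len(new_x)

class Phony:
    def __init__(self, val=4.0):
        self.val = val
    def init_xy(self, x, y):
        pass
    def __call__(self, new_x):
        return [self.val] * len(new_x)

x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [0.0, 1.0, 2.0, 3.0, 4.0]
interp = make_interpolator(Phony, x, y, {"val": 1.0})
low = make_interpolator(dummy, x, y, {"default": 7.1})
high = make_interpolator(dummy, x, y)
print(low([-1.0]) + interp([0.4]) + high([7.0]))   # [7.1, 1.0, 5.1]
```

Because extrap_high receives no keywords, dummy falls back to its default of 5.1 for the point above the range.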

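Interpolate1d.__call__ splits new_x into three regions with boolean masks and evaluates each with its own method; a FIXME there suggests allocating the result and filling it in. Below is a minimal NumPy sketch of that fill-in approach, which keeps every output aligned with its input position. evaluate_by_region is a hypothetical name, and each method is assumed to be a one-argument callable returning one value per point.

```python
import numpy as np

def evaluate_by_region(x, new_x, interp, extrap_low, extrap_high):
    """Evaluate new_x with extrap_low below x[0], interp on [x[0], x[-1]] and
    extrap_high above x[-1] (hypothetical helper)."""
    # Scalars and 0-dimensional arrays give a scalar result.
    input_is_scalar = np.ndim(new_x) == 0
    new_x = np.atleast_1d(np.asarray(new_x, dtype=float))

    low_mask = new_x < x[0]
    high_mask = new_x > x[-1]
    interp_mask = ~(low_mask | high_mask)

    # Fill a preallocated result so each value stays at its input position.
    result = np.empty(new_x.shape, dtype=float)
    for mask, method in ((low_mask, extrap_low),
                         (interp_mask, interp),
                         (high_mask, extrap_high)):
        if mask.any():   # skip methods whose region is empty
            result[mask] = method(new_x[mask])
    return float(result[0]) if input_is_scalar else result

x = np.arange(5.0)
y = np.arange(5.0)

def linear(new_x):
    return np.interp(new_x, x, y)

def nan(new_x):
    return np.full(len(new_x), np.nan)

print(evaluate_by_region(x, [7.0, 0.2, -1.0, 2.3], linear, nan, nan))
print(evaluate_by_region(x, 1.5, linear, nan, nan))   # 1.5, a float
```

Skipping empty regions sidesteps calling a method on a zero-length array, which is the case the "if/else is a hack" comment in __call__ works around.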