[Scipy-svn] r6215 - trunk/scipy/interpolate

scipy-svn@scip... scipy-svn@scip...
Sun Feb 7 01:13:04 CST 2010

```Author: oliphant
Date: 2010-02-07 01:13:04 -0600 (Sun, 07 Feb 2010)
New Revision: 6215

Modified:
trunk/scipy/interpolate/rbf.py
Log:
Add ability to use arbitrary basis function to Rbf constructor for radial basis function interpolation.

Modified: trunk/scipy/interpolate/rbf.py
===================================================================
--- trunk/scipy/interpolate/rbf.py	2010-02-05 04:31:54 UTC (rev 6214)
+++ trunk/scipy/interpolate/rbf.py	2010-02-07 07:13:04 UTC (rev 6215)
@@ -3,6 +3,7 @@
Written by John Travers <jtravs@gmail.com>, February 2007
Based closely on Matlab code by Alex Chirokov
Additional, large, improvements by Robert Hetland
+Some additional alterations by Travis Oliphant

Permission to use, modify, and distribute this software is given under the
terms of the SciPy (BSD style) license.  See LICENSE.txt that came with
@@ -42,10 +43,11 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

-from numpy import (sqrt, log, asarray, newaxis, all, dot, float64, exp, eye,
-                   isnan, float_)
+from numpy import (sqrt, log, asarray, newaxis, all, dot, exp, eye,
+                   float_)
from scipy import linalg

+
class Rbf(object):
"""
Rbf(*args)
@@ -58,18 +60,22 @@
*args : arrays
x, y, z, ..., d, where x, y, z, ... are the coordinates of the nodes
and d is the array of values at the nodes
-    function : str, optional
+    function : str or callable, optional
The radial basis function, based on the radius, r, given by the norm
(default is Euclidean distance); the default is 'multiquadric'::

'multiquadric': sqrt((r/self.epsilon)**2 + 1)
-            'inverse multiquadric': 1.0/sqrt((r/self.epsilon)**2 + 1)
+            'inverse': 1.0/sqrt((r/self.epsilon)**2 + 1)
'gaussian': exp(-(r/self.epsilon)**2)
'linear': r
'cubic': r**3
'quintic': r**5
-            'thin-plate': r**2 * log(r)
+            'thin_plate': r**2 * log(r)

+        If callable, it must take either one argument, r, or two arguments,
+        (self, r).  In the two-argument form the epsilon parameter is available
+        as self.epsilon, and any other keyword arguments passed to the
+        constructor are available as attributes of self as well.
+
epsilon : float, optional
Adjustable constant for gaussian or multiquadrics functions
- defaults to approximate average distance between nodes (which is
@@ -99,26 +105,67 @@
def _euclidean_norm(self, x1, x2):
return sqrt( ((x1 - x2)**2).sum(axis=0) )

-    def _function(self, r):
-        if self.function.lower() == 'multiquadric':
+    # Built-in radial basis functions.  _init_function resolves a string
+    # function name 'foo' to the method _h_foo.
+    def _h_multiquadric(self, r):
+        """Multiquadric basis: sqrt((r/epsilon)**2 + 1)."""
return sqrt((1.0/self.epsilon*r)**2 + 1)
-        elif self.function.lower() == 'inverse multiquadric':
+    def _h_inverse_multiquadric(self, r):
+        """Inverse multiquadric basis: 1/sqrt((r/epsilon)**2 + 1)."""
return 1.0/sqrt((1.0/self.epsilon*r)**2 + 1)
-        elif self.function.lower() == 'gaussian':
+    def _h_gaussian(self, r):
+        """Gaussian basis: exp(-(r/epsilon)**2)."""
return exp(-(1.0/self.epsilon*r)**2)
-        elif self.function.lower() == 'linear':
-            return r
-        elif self.function.lower() == 'cubic':
-            return r**3
-        elif self.function.lower() == 'quintic':
-            return r**5
-        elif self.function.lower() == 'thin-plate':
-            result = r**2 * log(r)
-            result[r == 0] = 0 # the spline is zero at zero
-            return result
-        else:
-            raise ValueError, 'Invalid basis function name'
+    def _h_linear(self, r):
+        """Linear basis: r."""
+        return r
+    def _h_cubic(self, r):
+        """Cubic basis: r**3."""
+        return r**3
+    def _h_quintic(self, r):
+        """Quintic basis: r**5."""
+        return r**5
+    def _h_thin_plate(self, r):
+        """Thin-plate spline basis: r**2 * log(r), taken as 0 where r == 0.
+
+        At r == 0 the product is 0 * log(0) = 0 * -inf = nan, so those
+        entries are overwritten with the limit value 0.  The boolean-mask
+        assignment assumes r is an ndarray -- TODO confirm for all callers.
+        """
+        result = r**2 * log(r)
+        result[r == 0] = 0 # the spline is zero at zero
+        return result

+    # Setup self._function and do smoke test on initial r
+    def _init_function(self, r):
+        """Resolve self.function into self._function and evaluate it on r.
+
+        self.function may be
+
+          * a string naming a built-in basis function, i.e. a method
+            ``_h_<name>`` (case-insensitive; the aliases 'inverse',
+            'inverse multiquadric' and 'thin-plate' are also accepted);
+          * a callable of one argument, ``f(r)`` -- a plain function, a
+            bound method (such as the default ``self._h_multiquadric``),
+            or an instance whose ``__call__`` takes just ``r``;
+          * a callable of two arguments, ``f(self, r)``, which is bound to
+            this instance so it can use self.epsilon and any extra keyword
+            arguments stored on the instance.
+
+        Returns the basis function evaluated at r.
+
+        Raises ValueError for an unknown function name, a value that is
+        neither a string nor a callable, a callable whose argument count
+        cannot be determined or is not 1 or 2, or a result whose shape
+        differs from r's.
+        """
+        if isinstance(self.function, basestring):
+            self.function = self.function.lower()
+            _mapped = {'inverse': 'inverse_multiquadric',
+                       'inverse multiquadric': 'inverse_multiquadric',
+                       'thin-plate': 'thin_plate'}
+            if self.function in _mapped:
+                self.function = _mapped[self.function]
+
+            func_name = "_h_" + self.function
+            if hasattr(self, func_name):
+                self._function = getattr(self, func_name)
+            else:
+                functionlist = [x[3:] for x in dir(self) if x.startswith('_h_')]
+                # Build one message string: the three-expression form
+                # ``raise E, msg, extra`` would treat ``extra`` as a traceback
+                # and fail with TypeError instead of raising ValueError.
+                raise ValueError("function must be a callable or one of " +
+                                 ", ".join(functionlist))
+        elif callable(self.function):
+            import new
+            func = self.function
+            if hasattr(func, 'im_func'):
+                # (Un)bound method.  A bound method already supplies its first
+                # argument, so that slot must not be counted; otherwise e.g.
+                # self._h_multiquadric would be re-bound and called with
+                # three arguments.
+                code = func.im_func.func_code
+                already_bound = func.im_self is not None
+            elif hasattr(func, 'func_code'):
+                # Plain Python function.
+                code = func.func_code
+                already_bound = False
+            elif hasattr(getattr(func, '__call__', None), 'im_func'):
+                # Callable instance: its __call__ is bound to the instance.
+                code = func.__call__.im_func.func_code
+                already_bound = True
+            else:
+                # e.g. builtins/ufuncs, whose signature cannot be inspected.
+                raise ValueError("Cannot determine number of arguments to function")
+
+            argcount = code.co_argcount
+            if already_bound:
+                argcount -= 1
+            if argcount == 1:
+                self._function = func
+            elif argcount == 2:
+                # Bind so the callable receives this Rbf as its first argument.
+                self._function = new.instancemethod(func, self, Rbf)
+            else:
+                raise ValueError("Function argument must take 1 or 2 arguments.")
+        else:
+            raise ValueError("function must be a string or a callable")
+
+        a0 = self._function(r)
+        # getattr: a callable returning a non-array should give the intended
+        # ValueError rather than an AttributeError on .shape.
+        if getattr(a0, 'shape', None) != r.shape:
+            raise ValueError("Callable must take array and return array of the same shape")
+        return a0
+
def __init__(self, *args, **kwargs):
self.xi = asarray([asarray(a, dtype=float_).flatten()
for a in args[:-1]])
@@ -131,10 +178,17 @@
self.norm = kwargs.pop('norm', self._euclidean_norm)
r = self._call_norm(self.xi, self.xi)
self.epsilon = kwargs.pop('epsilon', r.mean())
-        self.function = kwargs.pop('function', 'multiquadric')
self.smooth = kwargs.pop('smooth', 0.0)

-        self.A = self._function(r) - eye(self.N)*self.smooth
+        self.function = kwargs.pop('function', self._h_multiquadric)
+
+        # attach anything left in kwargs to self
+        #  for use by any user-callable function or
+        #  to save on the object returned.
+        for item, value in kwargs.items():
+            setattr(self, item, value)
+
+        self.A = self._init_function(r) - eye(self.N)*self.smooth
self.nodes = linalg.solve(self.A, self.di)

def _call_norm(self, x1, x2):

```

More information about the Scipy-svn mailing list