[Scipy-svn] r2690 - in trunk/Lib/sandbox: . rbf rbf/tests spline/tests

scipy-svn@scip... scipy-svn@scip...
Thu Feb 8 06:05:31 CST 2007


Author: jtravs
Date: 2007-02-08 06:05:18 -0600 (Thu, 08 Feb 2007)
New Revision: 2690

Added:
   trunk/Lib/sandbox/rbf/
   trunk/Lib/sandbox/rbf/README.txt
   trunk/Lib/sandbox/rbf/__init__.py
   trunk/Lib/sandbox/rbf/info.py
   trunk/Lib/sandbox/rbf/rbf.py
   trunk/Lib/sandbox/rbf/setup.py
   trunk/Lib/sandbox/rbf/tests/
   trunk/Lib/sandbox/rbf/tests/example.py
   trunk/Lib/sandbox/rbf/tests/test_rbf.py
Modified:
   trunk/Lib/sandbox/setup.py
   trunk/Lib/sandbox/spline/tests/test_fitpack.py
Log:
Added new rbf package to sandbox


Added: trunk/Lib/sandbox/rbf/README.txt
===================================================================
--- trunk/Lib/sandbox/rbf/README.txt	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/README.txt	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,8 @@
+This package uses radial basis functions (RBFs) for n-dimensional
+smoothing/interpolation of scattered data.
+
+It is closely based on the MATLAB code by Alex Chirokov, found at:
+
+http://www.mathworks.com/matlabcentral/fileexchange/loadFile.do?objectId=10056&objectType=FILE
+
+John Travers
\ No newline at end of file

Added: trunk/Lib/sandbox/rbf/__init__.py
===================================================================
--- trunk/Lib/sandbox/rbf/__init__.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/__init__.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,11 @@
+#
+# rbf - Radial Basis Functions
+#
+
+# Reuse the package docstring defined in the sibling info.py module.
+from info import __doc__
+
+# Re-export the public names of the rbf.py submodule (e.g. the Rbf class).
+from rbf import *
+
+# Export every non-underscore name bound so far.  This line must stay *after*
+# the star import (so Rbf is included) and *before* the numpy.testing import
+# below (so NumpyTest/test are not exported).
+# NOTE(review): rbf.py defines no __all__, so its module-level aliases
+# (`s`, `scipy`) are star-imported and end up in __all__ too -- confirm
+# that is intended.
+__all__ = filter(lambda s:not s.startswith('_'),dir())
+# Expose a package-level `test()` entry point via numpy's test runner.
+from numpy.testing import NumpyTest
+test = NumpyTest().test
\ No newline at end of file

Added: trunk/Lib/sandbox/rbf/info.py
===================================================================
--- trunk/Lib/sandbox/rbf/info.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/info.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,9 @@
+"""
+Radial Basis Functions
+===================
+
+rbf - Radial basis functions for interpolation/smoothing.
+
+"""
+
+postpone_import = 1

Added: trunk/Lib/sandbox/rbf/rbf.py
===================================================================
--- trunk/Lib/sandbox/rbf/rbf.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/rbf.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,126 @@
+#!/usr/bin/env python
+"""
+rbf - Radial basis functions for interpolation/smoothing of scattered N-d data.
+
+Written by John Travers <jtravs@gmail.com>, February 2007
+Based closely on Matlab code by Alex Chirokov
+
+Permission to use, modify, and distribute this software is given under the
+terms of the SciPy (BSD style) license.  See LICENSE.txt that came with
+this distribution for specifics.
+
+NO WARRANTY IS EXPRESSED OR IMPLIED.  USE AT YOUR OWN RISK.
+
+"""
+
+import scipy as s
+import scipy.linalg
+
+class Rbf(object):
+    """ A class for radial basis function approximation/interpolation of
+        n-dimensional scattered data.
+    """
+    def __init__(self,x,y, function='multiquadrics', constant=None, smooth=0):
+        """ Constructor for Rbf class.
+
+            Inputs:
+                x   (dim, n) array of coordinates for the nodes
+                y   (n,) array of values at the nodes
+                function    the radial basis function
+                            'linear', 'cubic' 'thinplate', 'multiquadrics'
+                            or 'gaussian', default is 'multiquadrics'
+                constant    adjustable constant for gaussian or multiquadrics
+                            functions - defaults to approximate average distance
+                            between nodes (which is a good start)
+                smooth      values greater than zero increase the smoothness
+                            of the approximation. 
+                            0 is for interpolation (default), the function will
+                            always go through the nodal points in this case.
+
+            Outputs: None
+        """
+        if len(x.shape) == 1:
+            nxdim = 1
+            nx = x.shape[0]
+        else:
+            (nxdim, nx)=x.shape
+        if len(y.shape) == 1:
+            nydim = 1
+            ny = y.shape[0]
+        else:
+            (nydim, ny)=y.shape
+        x.shape = (nxdim, nx)
+        y.shape = (nydim, ny)
+        if nx != ny:
+            raise ValueError, 'x and y should have the same number of points'
+        if nydim != 1:
+            raise ValueError, 'y should be a length n vector'
+        self.x = x
+        self.y = y
+        self.function = function
+        if (constant==None 
+            and ((function == 'multiquadrics') or (function == 'gaussian'))):
+            # approx. average distance between the nodes
+            constant = (s.product(x.T.max(0)-x.T.min(0),axis=0)/nx)**(1/nxdim)
+        self.constant = constant
+        self.smooth = smooth
+        if self.function == 'linear':
+            self.phi = lambda r: r
+        elif self.function == 'cubic':
+            self.phi = lambda r: r*r*r
+        elif self.function == 'multiquadrics':
+            self.phi = lambda r: s.sqrt(1.0+r*r/(self.constant*self.constant))
+        elif self.function == 'thinplate':
+            self.phi = lambda r: r*r*s.log(r+1)
+        elif self.function == 'gaussian':
+            self.phi = lambda r: s.exp(-0.5*r*r/(self.rbfconst*self.constant))
+        else:
+            raise ValueError, 'unkown function'
+        A = self._rbf_assemble()
+        b=s.r_[y.T, s.zeros((nxdim+1, 1), float)]
+        self.coeff = s.linalg.solve(A,b)
+
+    def __call__(self, xi):
+        """ Evaluate the radial basis function approximation at points xi.
+
+            Inputs:
+                xi  (dim, n) array of coordinates for the points to evaluate at
+
+            Outputs:
+                y   (n,) array of values at the points xi
+        """
+        if len(xi.shape) == 1:
+            nxidim = 1
+            nxi = xi.shape[0]
+        else:
+            (nxidim, nxi)=xi.shape
+        xi.shape = (nxidim, nxi)
+        (nxdim, nx) = self.x.shape
+        if nxdim != nxidim:
+            raise ValueError, 'xi should have the same number of rows as an' \
+                              ' array used to create RBF interpolation'
+        f = s.zeros(nxi, float)
+        r = s.zeros(nx, float)
+        for i in range(nxi):
+            st=0.0
+            r = s.dot(xi[:,i,s.newaxis],s.ones((1,nx))) - self.x
+            r = s.sqrt(sum(r*r))
+            st = self.coeff[nx,:] + s.sum(self.coeff[0:nx,:].flatten()*self.phi(r))
+            for k in range(nxdim):
+                st=st+self.coeff[k+nx+1,:]*xi[k,i]
+            f[i] = st
+        return f
+
+    def _rbf_assemble(self):
+        (nxdim, nx)=self.x.shape
+        A=s.zeros((nx,nx), float)
+        for i in range(nx):
+            for j in range(i+1):
+                r=s.linalg.norm(self.x[:,i]-self.x[:,j])
+                temp=self.phi(r)
+                A[i,j]=temp
+                A[j,i]=temp
+            A[i,i] = A[i,i] - self.smooth
+        P = s.c_[s.ones((nx,1), float), self.x.T]
+        A = s.r_[s.c_[A, P], s.c_[P.T, s.zeros((nxdim+1,nxdim+1), float)]]
+        return A

Added: trunk/Lib/sandbox/rbf/setup.py
===================================================================
--- trunk/Lib/sandbox/rbf/setup.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/setup.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+
+import os
+
+def configuration(parent_package='',top_path=None):
+    from numpy.distutils.misc_util import Configuration
+ 
+    config = Configuration('rbf', parent_package, top_path)
+
+    config.add_data_dir('tests')
+
+    return config
+
+if __name__ == '__main__':
+    from numpy.distutils.core import setup
+    setup(**configuration(top_path='').todict())
\ No newline at end of file

Added: trunk/Lib/sandbox/rbf/tests/example.py
===================================================================
--- trunk/Lib/sandbox/rbf/tests/example.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/tests/example.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,52 @@
+import scipy as s
+import scipy.interpolate
+
+from scipy.sandbox.rbf import Rbf
+
+import matplotlib
+matplotlib.use('Agg')
+import pylab as p
+
+# 1d tests - setup data
+x = s.linspace(0,10,9)
+y = s.sin(x) 
+xi = s.linspace(0,10,101)
+
+# use interpolate methods
+ius = s.interpolate.InterpolatedUnivariateSpline(x,y)
+yi = ius(xi)
+p.subplot(2,1,1)
+p.plot(x,y,'o',xi,yi, xi, s.sin(xi),'r')
+p.title('Interpolation using current scipy fitpack2')
+
+# use RBF method
+rbf = Rbf(x, y)
+fi = rbf(xi)
+p.subplot(2,1,2)
+p.plot(x,y,'bo',xi.flatten(),fi.flatten(),'g',xi.flatten(),
+                                                    s.sin(xi.flatten()),'r')
+p.title('RBF interpolation - multiquadrics')
+p.savefig('rbf1dtest.png')
+p.close()
+
+# 2-d tests - setup scattered data
+x = s.rand(50,1)*4-2
+y = s.rand(50,1)*4-2
+z = x*s.exp(-x**2-y**2)
+ti = s.linspace(-2.0,2.0,81)
+(XI,YI) = s.meshgrid(ti,ti)
+
+# use RBF
+rbf = Rbf(s.c_[x.flatten(),y.flatten()].T,z.T,constant=2)
+ZI = rbf(s.c_[XI.flatten(), YI.flatten()].T)
+ZI.shape = XI.shape
+
+# plot the result
+from enthought.tvtk.tools import mlab
+f=mlab.figure(browser=False)
+su=mlab.Surf(XI,YI,ZI,ZI,scalar_visibility=True)
+f.add(su)
+su.lut_type='blue-red'
+f.objects[0].axis.z_label='value'
+pp = mlab.Spheres(s.c_[x.flatten(), y.flatten(), z.flatten()],radius=0.03)
+f.add(pp)
\ No newline at end of file


Property changes on: trunk/Lib/sandbox/rbf/tests/example.py
___________________________________________________________________
Name: svn:executable
   + *

Added: trunk/Lib/sandbox/rbf/tests/test_rbf.py
===================================================================
--- trunk/Lib/sandbox/rbf/tests/test_rbf.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/tests/test_rbf.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# Created by John Travers, February 2007
+""" Test functions for rbf module """
+
+from numpy.testing import *
+import numpy as n
+
+set_package_path()
+from rbf.rbf import Rbf
+restore_path()
+
+class test_Rbf1D(NumpyTestCase):
+    def check_multiquadrics(self):
+        x = n.linspace(0,10,9)
+        y = n.sin(x) 
+        rbf = Rbf(x, y)
+        yi = rbf(x)
+        assert_array_almost_equal(y.flatten(), yi)
+
+class test_Rbf2D(NumpyTestCase):
+    def check_multiquadrics(self):
+        x = n.random.rand(50,1)*4-2
+        y = n.random.rand(50,1)*4-2
+        z = x*n.exp(-x**2-y**2)
+        rbf = Rbf(n.c_[x.flatten(),y.flatten()].T,z.T,constant=2)
+        zi = rbf(n.c_[x.flatten(), y.flatten()].T)
+        zi.shape = x.shape
+        assert_array_almost_equal(z, zi)
+
+if __name__ == "__main__":
+    NumpyTest().run()
\ No newline at end of file

Modified: trunk/Lib/sandbox/setup.py
===================================================================
--- trunk/Lib/sandbox/setup.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/setup.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -81,6 +81,9 @@
 
     # New spline package (based on scipy.interpolate)
     #config.add_subpackage('spline')
+    
+    # Radial basis functions package
+    #config.add_subpackage('rbf')
 
     return config
 

Modified: trunk/Lib/sandbox/spline/tests/test_fitpack.py
===================================================================
--- trunk/Lib/sandbox/spline/tests/test_fitpack.py	2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/spline/tests/test_fitpack.py	2007-02-08 12:05:18 UTC (rev 2690)
@@ -142,3 +142,6 @@
                                                                      decimal=1)
                 assert_almost_equal(0.0, 
                             around(abs(splev(uv[0],tck)-f(uv[0])),2),decimal=1)
+
+if __name__ == "__main__":
+    NumpyTest().run()
\ No newline at end of file



More information about the Scipy-svn mailing list