# [SciPy-User] multidimensional least squares fitting

Wed Nov 9 11:12:22 CST 2011

Hi everyone,

I'd like to do some fairly simple multidimensional curve fitting:
simple because it's usually just a plane or a weak 2nd-order
surface.
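Written out, the two models and the least-squares objective over an $n_y \times n_x$ grid look like this. This is a sketch; $p_3$ to $p_5$ are coefficient names introduced here for the quadratic terms.

$$
\begin{aligned}
f_{\text{plane}}(x, y) &= p_0 + p_1 x + p_2 y,\\
f_{\text{quad}}(x, y)  &= p_0 + p_1 x + p_2 y + p_3 x^2 + p_4 x y + p_5 y^2,\\
\min_{p}\; S(p) &= \sum_{i=1}^{n_y} \sum_{j=1}^{n_x} \bigl(f(x_{ij}, y_{ij}; p) - z_{ij}\bigr)^2 .
\end{aligned}
$$

The double sum is just a single sum over all $n_x n_y$ grid points, so the 2-D residuals can be treated as one long vector.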

Here's the same question on Stack Overflow. I tried to follow the
answers, but I have no clue how to feed 2D arrays into leastsq():
http://stackoverflow.com/questions/529184/simple-multidimensional-curve-fitting

Likely I am just missing a small piece of information, and I'd be
happy to get a clue :)
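A minimal sketch of the usual workaround, assuming the documented `leastsq` convention that the residual function returns a 1-D array of M residuals: keep the grids 2-D (for plotting, etc.), but have the residual function return the 2-D residual flattened with `.ravel()`. The function name `plane_residual` and the noise-free test data are illustrative assumptions, not from the post.

```python
import numpy as np
from scipy.optimize import leastsq


def plane_residual(params, z, x, y):
    """Residual of the plane z = p0 + p1*x + p2*y, flattened to 1-D for leastsq."""
    p0, p1, p2 = params
    # x, y, z are 2-D grids; ravel() turns the (ny, nx) residual into shape (ny*nx,)
    return (p0 + p1 * x + p2 * y - z).ravel()


# Illustrative 2-D grid (made-up, noise-free data)
x = np.arange(0, 11, 1.0)
y = x / 10.0
xx, yy = np.meshgrid(x, y)          # both have shape (11, 11)
zz = 0.5 + 1.1 * xx + 1.5 * yy      # exact plane

p_guess = [0.0, 1.0, 1.0]
# without full_output, leastsq returns (solution, integer status flag)
p_fit, ier = leastsq(plane_residual, p_guess, args=(zz, xx, yy))
print(p_fit)
```

Since the data lie exactly on a plane, the fit should recover values very close to 0.5, 1.1 and 1.5. Per the `leastsq` documentation, `ier` values 1 to 4 mean a solution was found.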

And here's some code that demonstrates what I want, as a starting point.
Daniel

```python
import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older matplotlib)

## fit function:
## f = p0 + p1*x + p2*y

##------------------------------------------------------------------------------
def __residual(params, f, x, y):
    '''
    Define the fit function (a plane);
    return the residual error.
    '''
    p0, p1, p2 = params
    return p0 + p1*x + p2*y - f

## load raw data (= create some dummy data):
dataX = np.arange(0, 11, 1)
dataY = dataX/10.
dataZ = 0.5 + 1.1*dataX + 1.5*dataY
dataXX, dataYY = np.meshgrid(dataX, dataY)
dataZZ = 0.5 + 1.1*dataXX + 1.5*dataYY

## plot data
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(dataXX, dataYY, dataZZ)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

## guess initial values for parameters
p0 = [0., 1., 1.]
print(__residual(p0, dataZZ, dataXX, dataYY))

## works but is not 2D!
## (dataY = dataX/10 here, so x and y are collinear and p1, p2 are not
## separately determined by this 1-D fit; also, without full_output=True,
## leastsq returns (solution, integer status flag), not a covariance matrix)
p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZ, dataX, dataY))
print(p1)

## doesn't work :(
## (here __residual returns a 2-D (11, 11) array, but leastsq expects the
## function to return a 1-D sequence of M residuals)
p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZZ, dataXX, dataYY))
print(p1)
```
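Because both the plane and the weak 2nd-order surface are linear in their parameters, one alternative to the iterative `leastsq` is a direct linear least-squares solve with `numpy.linalg.lstsq` on a design matrix built from the flattened grids. This is a swapped-in technique, not what the post or the linked answer uses; the helper `fit_quadratic_surface` and the test surface below are hypothetical, for illustration only.

```python
import numpy as np


def fit_quadratic_surface(xx, yy, zz):
    """Fit z = p0 + p1*x + p2*y + p3*x**2 + p4*x*y + p5*y**2 to 2-D grids.

    Solves the linear least-squares problem directly with numpy.linalg.lstsq;
    xx, yy and zz must all have the same shape.
    """
    x, y, z = xx.ravel(), yy.ravel(), zz.ravel()
    # one column per coefficient, one row per grid point
    A = np.column_stack([np.ones_like(x), x, y, x**2, x * y, y**2])
    coeffs, residuals, rank, sv = np.linalg.lstsq(A, z, rcond=None)
    return coeffs


# Made-up "weak 2nd-order" surface on the same kind of grid as in the post
x = np.arange(0, 11, 1.0)
y = x / 10.0
xx, yy = np.meshgrid(x, y)
true = [0.5, 1.1, 1.5, 0.02, -0.05, 0.3]
zz = (true[0] + true[1] * xx + true[2] * yy
      + true[3] * xx**2 + true[4] * xx * yy + true[5] * yy**2)

p = fit_quadratic_surface(xx, yy, zz)
print(p)
```

To fit just a plane, drop the `x**2`, `x * y` and `y**2` columns from `A`. No initial guess is needed, and the solve is non-iterative.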