# [SciPy-User] multidimensional least squares fitting

Wed Nov 9 11:21:57 CST 2011

```
On Wed, Nov 9, 2011 at 12:12 PM, Daniel Mader wrote:
> Hi everyone,
> I'd like to do some rather simple multidimensional curve fitting.
> Simple, because usually it's only a plane, or a weak 2nd order
> surface.
> Here's the same question, whose answer I tried to follow, but I have no
> clue how to feed the 2D arrays into leastsq():
> http://stackoverflow.com/questions/529184/simple-multidimensional-curve-fitting

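If you just want to estimate a multiple linear or polynomial function
(anything that is linear in the parameters), then optimize.leastsq is
overkill; a linear least-squares solve with linalg (e.g. numpy.linalg.lstsq)
is enough.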

A ready-made solution is the OLS class in statsmodels:

statsmodels.sourceforge.net/generated/scikits.statsmodels.regression.linear_model.OLS.html

Josef

> Likely I am just missing a small piece of information, and I'd be
> happy to get a clue :)
> And here's some code to demonstrate what I want, to get started.
> Daniel
> import numpy as np
> import matplotlib.pyplot as plt
> import scipy.optimize
> from mpl_toolkits.mplot3d import Axes3D  # registers '3d' projection on older matplotlib
> '''
> f = p0 + p1*x + p2*y
> '''
> ##------------------------------------------------------------------------------
> def __residual(params, f, x, y):
>    '''
>    Define fit function;
>    Return residual error as a 1-D array, which leastsq() requires.
>    '''
>    p0, p1, p2 = params
>    # ravel() flattens 2D grids to 1-D; 1-D input passes through unchanged
>    return np.ravel(p0 + p1*x + p2*y - f)
> ## load raw data (=create some dummy data):
> dataX = np.arange(0, 11, 1)
> dataY = dataX/10.
> dataZ = 0.5 + 1.1*dataX + 1.5*dataY
> dataXX, dataYY = np.meshgrid(dataX, dataY)
> dataZZ = 0.5 + 1.1*dataXX + 1.5*dataYY
> ## plot data
> plt.close('all')
> fig = plt.figure()
> ax = fig.add_subplot(111, projection='3d')
> ax.plot_wireframe(dataXX, dataYY, dataZZ)
> ax.set_xlabel('x')
> ax.set_ylabel('y')
> ax.set_zlabel('z')
> plt.show()
> ## guess initial values for parameters
> p0 = [0., 1., 1.]
> print(__residual(p0, dataZZ, dataXX, dataYY))
> ## works but is not 2D! (dataY = dataX/10, so x and y are collinear
> ## here and p1, p2 cannot be determined separately)
> # leastsq returns (solution, status flag); the covariance needs full_output=True
> p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZ, dataX, dataY))
> print(p1)
> ## 2D grids: works once __residual() returns a flattened array
> p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZZ, dataXX, dataYY))
> print(p1)
```
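Josef's point that linalg is enough can be sketched as follows, assuming NumPy. Because a plane, and equally a weak 2nd-order surface, is linear in its parameters, the flattened grids can be stacked into a design matrix and solved with a single call to `numpy.linalg.lstsq`, with no starting guess and no iteration. The helper `design_matrix` and its `order` argument are hypothetical names chosen for this illustration.

```python
import numpy as np

# Dummy data on a 2D grid, built the same way as in the quoted script
x = np.arange(0, 11, 1)
y = x / 10.0
xx, yy = np.meshgrid(x, y)
zz = 0.5 + 1.1 * xx + 1.5 * yy

def design_matrix(xx, yy, order=1):
    """Hypothetical helper: model-matrix columns for a plane (order=1)
    or a full quadratic surface (order=2); grids are flattened to 1-D."""
    x1, y1 = xx.ravel(), yy.ravel()
    cols = [np.ones_like(x1), x1, y1]
    if order == 2:
        cols += [x1**2, x1 * y1, y1**2]
    return np.column_stack(cols)

# Plane f = p0 + p1*x + p2*y: one linear least-squares solve
A = design_matrix(xx, yy, order=1)
params, residues, rank, sv = np.linalg.lstsq(A, zz.ravel(), rcond=None)
print(params)  # approximately [0.5, 1.1, 1.5]

# A 2nd-order surface only adds columns; it is still linear in the parameters
A2 = design_matrix(xx, yy, order=2)
params2, *_ = np.linalg.lstsq(A2, zz.ravel(), rcond=None)
print(params2)  # quadratic terms come out near zero for this planar data
```

Flattening with `ravel()` is the same step `leastsq` needs; the linear solver simply receives the whole model as one matrix instead of a residual function.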