[SciPy-User] multidimensional polynomial fit

josef.pktd@gmai...
Sat Jun 12 14:43:08 CDT 2010

```
On Sat, Jun 12, 2010 at 2:36 PM, Oscar Gerardo Lazo Arjona
<algebraicamente@gmail.com> wrote:
> Hello!
>
> Is there some way to get a polynomial fit to a set of n-tuples? I've got
> a set of 4-tuples: (x1,x2,x3,T), and I would like to get a polynomial
> T(x1,x2,x3).
>
> I've seen numpy.polyfit, but that doesn't work for multidimensional sets.
>
> If there is no method available, I would be willing to write the
> necessary code, just tell me how to get it included.

Assuming I understand correctly, you want to fit the last variable, T,
as a polynomial in the first three.

How to do that depends on how many cross terms you want.

Here is an example that restricts the powers in the cross terms, keeping
only monomials of total degree at most 2:

>>> import numpy as np
>>> x = np.arange(5)[:,None] + [0,10,100]
>>> x = x[:,::-1] #reverse for ndindex
>>> x
array([[100,  10,   0],
       [101,  11,   1],
       [102,  12,   2],
       [103,  13,   3],
       [104,  14,   4]])
>>> simplex = [ind for ind in np.ndindex(*[3]*x.shape[1]) if sum(ind)<=2]
>>> simplex
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 2, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), (2, 0, 0)]
>>> np.array([np.prod(x**ind,1) for ind in simplex]).T
array([[    1,     0,     0,    10,     0,   100,   100,     0,  1000,
        10000],
       [    1,     1,     1,    11,    11,   121,   101,   101,  1111,
        10201],
       [    1,     2,     4,    12,    24,   144,   102,   204,  1224,
        10404],
       [    1,     3,     9,    13,    39,   169,   103,   309,  1339,
        10609],
       [    1,     4,    16,    14,    56,   196,   104,   416,  1456,
        10816]])

>>> nobs = 100
>>> x0 = np.random.randn(nobs,3)
>>> x = np.array([np.prod(x0**ind,1) for  ind in simplex]).T
>>> y = x.sum(1) + 0.1*np.random.randn(nobs)
>>> y.shape
(100,)
>>> from scikits.statsmodels import OLS
>>> res = OLS(y, x).fit()
>>> res.params
array([ 1.02381284,  1.00619277,  0.99437357,  0.96839791,  1.00923175,
        1.00342817,  0.99046168,  1.00125689,  0.99069758,  0.98808115])
>>> yest = res.model.predict(x)
>>> import matplotlib.pyplot as plt
>>> plt.plot(y, yest, 'o')  # observed vs. fitted values

The use of OLS can be replaced by np.linalg.lstsq.

Josef

>
> thanks!
>
> Oscar
> _______________________________________________
> SciPy-User mailing list
> SciPy-User@scipy.org
> http://mail.scipy.org/mailman/listinfo/scipy-user
>
```
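A minimal sketch of the np.linalg.lstsq route mentioned in the reply, assuming NumPy only. The helper names (`monomial_exponents`, `design_matrix`, `polyfit_nd`, `polyval_nd`) are hypothetical and not part of NumPy or SciPy. The sketch builds the same total-degree design matrix as the `np.ndindex`/`sum(ind) <= degree` filter in the post and solves the least-squares problem directly. (The package imported above as `scikits.statsmodels` is now distributed as `statsmodels`, where the same estimator is `statsmodels.api.OLS`.)

```python
import numpy as np

def monomial_exponents(nvars, degree):
    # Exponent tuples of total degree <= degree, enumerated as in the post:
    # np.ndindex over a (degree+1)**nvars grid, then filtered by the sum.
    return [ind for ind in np.ndindex(*[degree + 1] * nvars) if sum(ind) <= degree]

def design_matrix(X, exponents):
    # One column per monomial: column k is prod_j X[:, j] ** exponents[k][j].
    X = np.asarray(X, dtype=float)
    return np.column_stack([np.prod(X ** np.array(e), axis=1) for e in exponents])

def polyfit_nd(X, y, degree):
    # Least-squares fit of y to all monomials of total degree <= degree.
    X = np.asarray(X, dtype=float)
    exps = monomial_exponents(X.shape[1], degree)
    A = design_matrix(X, exps)
    coef, residuals, rank, sv = np.linalg.lstsq(A, y, rcond=None)
    return exps, coef

def polyval_nd(X, exps, coef):
    # Evaluate the fitted polynomial at the rows of X.
    return design_matrix(X, exps) @ coef

# Synthetic data in the spirit of the post: every true coefficient equals 1.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3))      # columns play the role of x1, x2, x3
T = design_matrix(X, monomial_exponents(3, 2)).sum(axis=1) + 0.1 * rng.standard_normal(100)
exps, coef = polyfit_nd(X, T, degree=2)
T_hat = polyval_nd(X, exps, coef)
print(np.round(coef, 2))               # each entry should be close to 1
```

Because lstsq works on the design matrix directly, the same helpers cover any number of input variables and any degree. The number of columns grows as C(nvars + degree, degree), so high degrees can make the fit badly conditioned.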