# [Scipy-svn] r4195 - in trunk/scipy/interpolate: . tests

scipy-svn@scip... scipy-svn@scip...
Sun Apr 27 23:59:05 CDT 2008

```
Author: peridot
Date: 2008-04-27 23:59:03 -0500 (Sun, 27 Apr 2008)
New Revision: 4195

Modified:
trunk/scipy/interpolate/polyint.py
trunk/scipy/interpolate/tests/test_polyint.py
Log:
Fix bug introduced in r4181; also PiecewisePolynomial now correctly distinguishes between scalar values and vectors of length 1.

Modified: trunk/scipy/interpolate/polyint.py
===================================================================
--- trunk/scipy/interpolate/polyint.py	2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/polyint.py	2008-04-28 04:59:03 UTC (rev 4195)
@@ -432,16 +432,20 @@
derivatives needed is odd, it will prefer the rightmost endpoint. If
not enough derivatives are available, an exception is raised.
"""
+        yi0 = np.asarray(yi[0])
+        if len(yi0.shape)==2:
+            self.vector_valued = True
+            self.r = yi0.shape[1]
+        elif len(yi0.shape)==1:
+            self.vector_valued = False
+            self.r = 1
+        else:
+            raise ValueError, "Each derivative must be a vector, not a higher-rank array"
+
self.xi = [xi[0]]
-        self.yi = [yi[0]]
+        self.yi = [yi0]
self.n = 1

-        try:
-            self.r = len(yi[0][0])
-        except TypeError:
-            self.r = 1
-
-        self.n = 1
self.direction = direction
self.orders = []
self.polynomials = []
@@ -468,7 +472,11 @@
assert n2<=len(y2)

xi = np.zeros(n)
-        yi = np.zeros((n,self.r))
+        if self.vector_valued:
+            yi = np.zeros((n,self.r))
+        else:
+            yi = np.zeros((n,))
+
xi[:n1] = x1
yi[:n1] = y1[:n1]
xi[n1:] = x2
@@ -488,19 +496,23 @@
a polynomial order, or instructions to use the highest
possible order
"""
+
+        yi = np.asarray(yi)
+        if self.vector_valued:
+            if (len(yi.shape)!=2 or yi.shape[1]!=self.r):
+                raise ValueError, "Each derivative must be a vector of length %d" % self.r
+        else:
+            if len(yi.shape)!=1:
+                raise ValueError, "Each derivative must be a scalar"
+
if self.direction is None:
self.direction = np.sign(xi-self.xi[-1])
elif (xi-self.xi[-1])*self.direction < 0:
raise ValueError, "x coordinates must be in the %d direction: %s" % (self.direction, self.xi)
+
self.xi.append(xi)
self.yi.append(yi)

-        for y in yi:
-            if np.shape(y) != (self.r,):
-                if self.r>1:
-                    raise ValueError, "Each derivative must be a vector of length %d" % self.r
-                else:
-                    raise ValueError, "Each derivative must be a scalar"

if order is None:
n1 = len(self.yi[-2])
@@ -558,7 +570,7 @@
x = np.asarray(x)
m = len(x)
pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
-            if self.r>1:
+            if self.vector_valued:
y = np.zeros((m,self.r))
else:
y = np.zeros(m)
@@ -611,7 +623,7 @@
x = np.asarray(x)
m = len(x)
pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
-            if self.r>1:
+            if self.vector_valued:
y = np.zeros((der,m,self.r))
else:
y = np.zeros((der,m))

Modified: trunk/scipy/interpolate/tests/test_polyint.py
===================================================================
--- trunk/scipy/interpolate/tests/test_polyint.py	2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/tests/test_polyint.py	2008-04-28 04:59:03 UTC (rev 4195)
@@ -219,6 +219,13 @@
assert_array_equal(np.shape(P([0])), (1,3))
assert_array_equal(np.shape(P([0,1])), (2,3))

+    def test_shapes_vectorvalue_1d(self):
+        yi = np.multiply.outer(np.asarray(self.yi),np.arange(1))
+        P = PiecewisePolynomial(self.xi,yi,4)
+        assert_array_equal(np.shape(P(0)), (1,))
+        assert_array_equal(np.shape(P([0])), (1,1))
+        assert_array_equal(np.shape(P([0,1])), (2,1))
+
def test_shapes_vectorvalue_derivative(self):
P = PiecewisePolynomial(self.xi,np.multiply.outer(self.yi,np.arange(3)),4)
n = 4

```