[Numpy-svn] r5291 - in trunk/numpy/lib: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jun 17 15:08:29 CDT 2008


Author: oliphant
Date: 2008-06-17 15:08:28 -0500 (Tue, 17 Jun 2008)
New Revision: 5291

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
Log:
Fix piecewise to handle 0-d inputs.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2008-06-17 13:06:08 UTC (rev 5290)
+++ trunk/numpy/lib/function_base.py	2008-06-17 20:08:28 UTC (rev 5291)
@@ -574,13 +574,32 @@
         n += 1
     if (n != n2):
         raise ValueError, "function list and condition list must be the same"
+    zerod = False
+    # This is a hack to work around problems with NumPy's 
+    #  handling of 0-d arrays and boolean indexing with 
+    #  numpy.bool_ scalars
+    if x.ndim == 0:
+        x = x[None]
+        zerod = True
+        newcondlist = []
+        for k in range(n):
+            if condlist[k].ndim == 0:
+                condition = condlist[k][None]
+            else:
+                condition = condlist[k]
+            newcondlist.append(condition)
+        condlist = newcondlist
     y = empty(x.shape, x.dtype)
     for k in range(n):
         item = funclist[k]
         if not callable(item):
             y[condlist[k]] = item
         else:
-            y[condlist[k]] = item(x[condlist[k]], *args, **kw)
+            vals = x[condlist[k]]
+            if vals.size > 0:
+                y[condlist[k]] = item(vals, *args, **kw)
+    if zerod:
+        y = y.squeeze()
     return y
 
 def select(condlist, choicelist, default=0):

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2008-06-17 13:06:08 UTC (rev 5290)
+++ trunk/numpy/lib/tests/test_function_base.py	2008-06-17 20:08:28 UTC (rev 5291)
@@ -616,5 +616,12 @@
     for i in range(len(desired)):
         assert_array_equal(res[i],desired[i])
 
+class TestPiecewise(TestCase):  # regression tests for piecewise (r5291)
+    def test_0d(self):  # piecewise must accept a 0-d array input
+        x = array(3)  # 0-d array; x>3 is then a 0-d boolean (numpy.bool_)
+        y = piecewise(x, x>3, [4, 0])  # one condition, two values; x>3 is False
+        assert y.ndim == 0  # 0-d input must give a 0-d result, not shape (1,)
+        assert y == 0  # the trailing 0 is used because x>3 does not hold
+
 if __name__ == "__main__":
     nose.run(argv=['', __file__])



More information about the Numpy-svn mailing list