[Numpy-svn] r5414 - in branches/1.1.x/numpy/lib: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jul 15 03:01:38 CDT 2008


Author: stefan
Date: 2008-07-15 03:01:26 -0500 (Tue, 15 Jul 2008)
New Revision: 5414

Modified:
   branches/1.1.x/numpy/lib/function_base.py
   branches/1.1.x/numpy/lib/tests/test_function_base.py
Log:
Merge changeset 5359:5360 from trunk.


Modified: branches/1.1.x/numpy/lib/function_base.py
===================================================================
--- branches/1.1.x/numpy/lib/function_base.py	2008-07-15 07:38:00 UTC (rev 5413)
+++ branches/1.1.x/numpy/lib/function_base.py	2008-07-15 08:01:26 UTC (rev 5414)
@@ -563,8 +563,11 @@
     """
     x = asanyarray(x)
     n2 = len(funclist)
-    if not isinstance(condlist, type([])):
+    if isscalar(condlist) or \
+           not (isinstance(condlist[0], list) or
+                isinstance(condlist[0], ndarray)):
         condlist = [condlist]
+    condlist = [asarray(c, dtype=bool) for c in condlist]
     n = len(condlist)
     if n == n2-1:  # compute the "otherwise" condition.
         totlist = condlist[0]
@@ -573,8 +576,26 @@
         condlist.append(~totlist)
         n += 1
     if (n != n2):
-        raise ValueError, "function list and condition list must be the same"
-    y = empty(x.shape, x.dtype)
+        raise ValueError, "function list and condition list " \
+                          "must be the same"
+
+    zerod = False
+    # This is a hack to work around problems with NumPy's
+    #  handling of 0-d arrays and boolean indexing with
+    #  numpy.bool_ scalars
+    if x.ndim == 0:
+        x = x[None]
+        zerod = True
+        newcondlist = []
+        for k in range(n):
+            if condlist[k].ndim == 0:
+                condition = condlist[k][None]
+            else:
+                condition = condlist[k]
+            newcondlist.append(condition)
+        condlist = newcondlist
+
+    y = zeros(x.shape, x.dtype)
     for k in range(n):
         item = funclist[k]
         if not callable(item):
@@ -1072,7 +1093,7 @@
             self.__doc__ = pyfunc.__doc__
         else:
             self.__doc__ = doc
-        if isinstance(otypes, types.StringType):
+        if isinstance(otypes, str):
             self.otypes = otypes
             for char in self.otypes:
                 if char not in typecodes['All']:
@@ -1103,7 +1124,7 @@
             for arg in args:
                 newargs.append(asarray(arg).flat[0])
             theout = self.thefunc(*newargs)
-            if isinstance(theout, types.TupleType):
+            if isinstance(theout, tuple):
                 self.nout = len(theout)
             else:
                 self.nout = 1

Modified: branches/1.1.x/numpy/lib/tests/test_function_base.py
===================================================================
--- branches/1.1.x/numpy/lib/tests/test_function_base.py	2008-07-15 07:38:00 UTC (rev 5413)
+++ branches/1.1.x/numpy/lib/tests/test_function_base.py	2008-07-15 08:01:26 UTC (rev 5414)
@@ -613,6 +613,52 @@
         x = array([5+6j, 1+1j, 1+10j, 10, 5+6j])
         assert(all(unique(x) == [1+1j, 1+10j, 5+6j, 10]))
 
+
+class TestPiecewise(NumpyTestCase):
+    def check_simple(self):
+        # Condition is single bool list
+        x = piecewise([0, 0], [True, False], [1])
+        assert_array_equal(x, [1, 0])
+
+        # List of conditions: single bool list
+        x = piecewise([0, 0], [[True, False]], [1])
+        assert_array_equal(x, [1, 0])
+
+        # Condition is single bool array
+        x = piecewise([0, 0], array([True, False]), [1])
+        assert_array_equal(x, [1, 0])
+
+        # Condition is single int array
+        x = piecewise([0, 0], array([1, 0]), [1])
+        assert_array_equal(x, [1, 0])
+
+        # List of conditions: int array
+        x = piecewise([0, 0], [array([1, 0])], [1])
+        assert_array_equal(x, [1, 0])
+
+
+        x = piecewise([0, 0], [[False, True]], [lambda x: -1])
+        assert_array_equal(x, [0, -1])
+
+        x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
+        assert_array_equal(x, [3, 4])
+
+    def check_default(self):
+        # No value specified for x[1], should be 0
+        x = piecewise([1, 2], [True, False], [2])
+        assert_array_equal(x, [2, 0])
+
+        # Should set x[1] to 3
+        x = piecewise([1, 2], [True, False], [2, 3])
+        assert_array_equal(x, [2, 3])
+
+#    def test_0d(self):
+#        x = array(3)
+#        y = piecewise(x, x>3, [4, 0])
+#        assert y.ndim == 0
+#        assert y == 0
+
+
 def compare_results(res,desired):
     for i in range(len(desired)):
         assert_array_equal(res[i],desired[i])



More information about the Numpy-svn mailing list