[Numpy-svn] r5360 - in trunk/numpy/lib: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jul 8 03:25:03 CDT 2008


Author: stefan
Date: 2008-07-08 03:24:37 -0500 (Tue, 08 Jul 2008)
New Revision: 5360

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
Log:
Piecewise should not expose raw memory.  Closes #798.


Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2008-07-08 07:49:25 UTC (rev 5359)
+++ trunk/numpy/lib/function_base.py	2008-07-08 08:24:37 UTC (rev 5360)
@@ -563,8 +563,11 @@
     """
     x = asanyarray(x)
     n2 = len(funclist)
-    if not isinstance(condlist, type([])):
+    if isscalar(condlist) or \
+           not (isinstance(condlist[0], list) or
+                isinstance(condlist[0], ndarray)):
         condlist = [condlist]
+    condlist = [asarray(c, dtype=bool) for c in condlist]
     n = len(condlist)
     if n == n2-1:  # compute the "otherwise" condition.
         totlist = condlist[0]
@@ -573,10 +576,11 @@
         condlist.append(~totlist)
         n += 1
     if (n != n2):
-        raise ValueError, "function list and condition list must be the same"
+        raise ValueError, "function list and condition list " \
+                          "must be the same"
     zerod = False
-    # This is a hack to work around problems with NumPy's 
-    #  handling of 0-d arrays and boolean indexing with 
+    # This is a hack to work around problems with NumPy's
+    #  handling of 0-d arrays and boolean indexing with
     #  numpy.bool_ scalars
     if x.ndim == 0:
         x = x[None]
@@ -589,7 +593,8 @@
                 condition = condlist[k]
             newcondlist.append(condition)
         condlist = newcondlist
-    y = empty(x.shape, x.dtype)
+
+    y = zeros(x.shape, x.dtype)
     for k in range(n):
         item = funclist[k]
         if not callable(item):
@@ -1090,7 +1095,7 @@
             self.__doc__ = pyfunc.__doc__
         else:
             self.__doc__ = doc
-        if isinstance(otypes, types.StringType):
+        if isinstance(otypes, str):
             self.otypes = otypes
             for char in self.otypes:
                 if char not in typecodes['All']:
@@ -1121,7 +1126,7 @@
             for arg in args:
                 newargs.append(asarray(arg).flat[0])
             theout = self.thefunc(*newargs)
-            if isinstance(theout, types.TupleType):
+            if isinstance(theout, tuple):
                 self.nout = len(theout)
             else:
                 self.nout = 1

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2008-07-08 07:49:25 UTC (rev 5359)
+++ trunk/numpy/lib/tests/test_function_base.py	2008-07-08 08:24:37 UTC (rev 5360)
@@ -615,16 +615,54 @@
         x = array([5+6j, 1+1j, 1+10j, 10, 5+6j])
         assert(all(unique(x) == [1+1j, 1+10j, 5+6j, 10]))
 
-def compare_results(res,desired):
-    for i in range(len(desired)):
-        assert_array_equal(res[i],desired[i])
 
-class TestPiecewise(TestCase):
+class TestPiecewise(NumpyTestCase):
+    def check_simple(self):
+        # Condition is a single bool list
+        x = piecewise([0, 0], [True, False], [1])
+        assert_array_equal(x, [1, 0])
+
+        # List of conditions containing one bool list
+        x = piecewise([0, 0], [[True, False]], [1])
+        assert_array_equal(x, [1, 0])
+
+        # Condition is a single bool array
+        x = piecewise([0, 0], array([True, False]), [1])
+        assert_array_equal(x, [1, 0])
+
+        # Condition is a single int array (converted to bool by piecewise)
+        x = piecewise([0, 0], array([1, 0]), [1])
+        assert_array_equal(x, [1, 0])
+
+        # List of conditions containing one int array (converted to bool)
+        x = piecewise([0, 0], [array([1, 0])], [1])
+        assert_array_equal(x, [1, 0])
+
+        # Callable entry: applied where its condition is True; other entries stay 0
+        x = piecewise([0, 0], [[False, True]], [lambda x: -1])
+        assert_array_equal(x, [0, -1])
+
+        x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
+        assert_array_equal(x, [3, 4])
+
+    def check_default(self):
+        # No condition covers x[1], so it must be 0 (output is zero-filled)
+        x = piecewise([1, 2], [True, False], [2])
+        assert_array_equal(x, [2, 0])
+
+        # One extra value (3) acts as the "otherwise" case, setting x[1] to 3
+        x = piecewise([1, 2], [True, False], [2, 3])
+        assert_array_equal(x, [2, 3])
+    # NOTE(review): check_* names rely on the NumpyTestCase runner; nose-based run_module_suite collects test*-named methods like test_0d below -- confirm check_simple/check_default actually run.
     def test_0d(self):
         x = array(3)
         y = piecewise(x, x>3, [4, 0])
         assert y.ndim == 0
         assert y == 0
 
+def compare_results(res,desired):
+    for i in range(len(desired)):
+        assert_array_equal(res[i],desired[i])
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Numpy-svn mailing list