[Numpy-svn] r3398 - trunk/numpy/lib

numpy-svn at scipy.org numpy-svn at scipy.org
Wed Oct 25 06:37:33 CDT 2006


Author: oliphant
Date: 2006-10-25 06:37:30 -0500 (Wed, 25 Oct 2006)
New Revision: 3398

Modified:
   trunk/numpy/lib/function_base.py
Log:
Fix vectorize bug ignoring otypes.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2006-10-25 10:41:11 UTC (rev 3397)
+++ trunk/numpy/lib/function_base.py	2006-10-25 11:37:30 UTC (rev 3398)
@@ -764,11 +764,17 @@
   Description:
 
     Define a vectorized function which takes nested sequence
-    objects or numpy arrays as inputs and returns a
+    of objects or numpy arrays as inputs and returns a
     numpy array as output, evaluating the function over successive
     tuples of the input arrays like the python map function except it uses
     the broadcasting rules of numpy.
 
+    Data-type of output of vectorized is determined by calling the function
+    with the first element of the input.  This can be avoided by specifying
+    the otypes argument as either a string of typecode characters or a list
+    of data-types specifiers.  There should be one data-type specifier for
+    each output. 
+
   Input:
 
     somefunction -- a Python function or method
@@ -804,11 +810,13 @@
             self.__doc__ = doc
         if isinstance(otypes, types.StringType):
             self.otypes = otypes
+            for char in self.otypes:
+                if char not in typecodes['All']:
+                    raise ValueError, "invalid otype specified"
+        elif iterable(otypes):
+            self.otypes = ''.join([_nx.dtype(x).char for x in otypes])
         else:
-            raise ValueError, "output types must be a string"
-        for char in self.otypes:
-            if char not in typecodes['All']:
-                raise ValueError, "invalid typecode specified"
+            raise ValueError, "output types must be a string of typecode characters or a list of data-types"
         self.lastcallargs = 0
 
     def __call__(self, *args):
@@ -835,10 +843,11 @@
             else:
                 self.nout = 1
                 theout = (theout,)
-            otypes = []
-            for k in range(self.nout):
-                otypes.append(asarray(theout[k]).dtype.char)
-            self.otypes = ''.join(otypes)
+            if self.otypes == '':
+                otypes = []
+                for k in range(self.nout):
+                    otypes.append(asarray(theout[k]).dtype.char)
+                self.otypes = ''.join(otypes)
 
         if (self.ufunc is None):
             self.ufunc = frompyfunc(self.thefunc, nargs, self.nout)



More information about the Numpy-svn mailing list