[Numpy-svn] r5021 - in trunk/numpy: core core/tests lib lib/tests

numpy-svn@scip... numpy-svn@scip...
Fri Apr 11 01:53:52 CDT 2008


Author: oliphant
Date: 2008-04-11 01:53:49 -0500 (Fri, 11 Apr 2008)
New Revision: 5021

Modified:
   trunk/numpy/core/numerictypes.py
   trunk/numpy/core/tests/test_numerictypes.py
   trunk/numpy/lib/index_tricks.py
   trunk/numpy/lib/tests/test_index_tricks.py
Log:
Fixed #728 scalar coercion problem with mixed types and r_

Modified: trunk/numpy/core/numerictypes.py
===================================================================
--- trunk/numpy/core/numerictypes.py	2008-04-11 06:34:20 UTC (rev 5020)
+++ trunk/numpy/core/numerictypes.py	2008-04-11 06:53:49 UTC (rev 5021)
@@ -78,7 +78,7 @@
 # we add more at the bottom
 __all__ = ['sctypeDict', 'sctypeNA', 'typeDict', 'typeNA', 'sctypes',
            'ScalarType', 'obj2sctype', 'cast', 'nbytes', 'sctype2char',
-           'maximum_sctype', 'issctype', 'typecodes']
+           'maximum_sctype', 'issctype', 'typecodes', 'find_common_type']
 
 from numpy.core.multiarray import typeinfo, ndarray, array, empty, dtype
 import types as _types
@@ -566,7 +566,7 @@
 
 del key
 
-typecodes = {'Character':'S1',
+typecodes = {'Character':'c',
              'Integer':'bhilqp',
              'UnsignedInteger':'BHILQP',
              'Float':'fdg',
@@ -578,3 +578,69 @@
 # backwards compatibility --- deprecated name
 typeDict = sctypeDict
 typeNA = sctypeNA
+
+_kind_list = ['b', 'u', 'i', 'f', 'c', 'S', 'U', 'V', 'O']  # kind ranking used by find_common_type, lowest first
+
+__test_types = typecodes['AllInteger'][:-2]+typecodes['AllFloat']+'O'  # candidates _find_common_coerce walks upward through
+__len_test_types = len(__test_types)  # NOTE(review): [:-2] presumably drops the pointer codes 'p'/'P' -- confirm against typecodes['AllInteger']
+
+# Return the first type in __test_types, searching upward from a's own code, that both a
+#  and b can be coerced to (a is the scalar maximum, b the array maximum); else None.
+def _find_common_coerce(a, b):
+    if a > b:  # a already strictly outranks b in the dtype ordering; keep a
+        return a
+    try:
+        thisind = __test_types.index(a.char)
+    except ValueError:  # a's type code is not in the search list
+        return None
+    while thisind < __len_test_types:
+        newdtype = dtype(__test_types[thisind])
+        if newdtype >= b and newdtype >= a:  # candidate can hold both inputs
+            return newdtype
+        thisind += 1
+    return None
+    
+
+def find_common_type(array_types, scalar_types):
+    """Determine common type following standard coercion rules.
+
+    Array types take precedence over scalar types: the scalar types only
+    influence the result when their maximum is of a higher kind than the
+    maximum of the array types, where kinds are ranked in the order
+    b, u, i, f, c, S, U, V, O (lowest first).
+
+    Parameters
+    ----------
+    array_types : sequence
+        A list of dtype convertible objects representing arrays
+    scalar_types : sequence
+        A list of dtype convertible objects representing scalars
+
+    Returns
+    -------
+    datatype : dtype or None
+        The maximum of the array_types, unless the maximum of the
+        scalar_types is of a higher kind.  In that case, the result is the
+        first type (searching upward from the scalar maximum) to which both
+        maxima can be coerced.  If one of the sequences is empty, the
+        maximum of the other is returned.
+
+        None is returned if both sequences are empty, if the kind of either
+        maximum is not in the ranking above, or if the upward search finds
+        no common type.
+
+    Examples
+    --------
+    >>> find_common_type(['f4', 'f4', 'i4'], ['f8'])
+    dtype('float32')
+    >>> find_common_type(['f4', 'f4', 'i4'], ['c8'])
+    dtype('complex64')
+    """
+    array_types = [dtype(x) for x in array_types]
+    scalar_types = [dtype(x) for x in scalar_types]
+
+    if len(scalar_types) == 0:
+        if len(array_types) == 0:
+            return None
+        else:
+            return max(array_types)
+    if len(array_types) == 0:
+        return max(scalar_types)
+
+    maxa = max(array_types)  # NOTE(review): dtype ordering looks partial (e.g. i4 vs u4); max() may depend on input order -- verify
+    maxsc = max(scalar_types)
+
+    try:
+        index_a = _kind_list.index(maxa.kind)
+        index_sc = _kind_list.index(maxsc.kind)
+    except ValueError:  # a kind outside _kind_list is not understood
+        return None
+    
+    if index_sc > index_a:  # scalar is of a higher kind, so it may upcast the result
+        return _find_common_coerce(maxsc,maxa)
+    else:  # same or lower kind: the scalars do not affect the result
+        return maxa

Modified: trunk/numpy/core/tests/test_numerictypes.py
===================================================================
--- trunk/numpy/core/tests/test_numerictypes.py	2008-04-11 06:34:20 UTC (rev 5020)
+++ trunk/numpy/core/tests/test_numerictypes.py	2008-04-11 06:53:49 UTC (rev 5021)
@@ -338,5 +338,27 @@
         assert(a['int'].shape == (5,0))
         assert(a['float'].shape == (5,2))
 
+class TestCommonType(NumpyTestCase):  # tests for numpy.find_common_type
+    def check_scalar_loses1(self):  # f8 scalar has the arrays' kind ('f'), so it is ignored
+        res = numpy.find_common_type(['f4','f4','i4'],['f8'])
+        assert(res == 'f4')
+    def check_scalar_loses2(self):  # i8 scalar ranks below kind 'f', so it is ignored
+        res = numpy.find_common_type(['f4','f4'],['i8'])
+        assert(res == 'f4')
+    def check_scalar_wins(self):  # c8 scalar is of a higher kind ('c'), so it wins
+        res = numpy.find_common_type(['f4','f4','i4'],['c8'])
+        assert(res == 'c8')
+    def check_scalar_wins2(self):  # 'f' scalar outranks 'u'/'i' arrays; result widens past f4 to f8
+        res = numpy.find_common_type(['u4','i4','i4'],['f4'])
+        assert(res == 'f8')
+    def check_scalar_wins3(self): # doesn't go up to 'f16' on purpose
+        res = numpy.find_common_type(['u8','i8','i8'],['f8'])
+        assert(res == 'f8')
+
+        
+
+        
+        
+
 if __name__ == "__main__":
     NumpyTest().run()

Modified: trunk/numpy/lib/index_tricks.py
===================================================================
--- trunk/numpy/lib/index_tricks.py	2008-04-11 06:34:20 UTC (rev 5020)
+++ trunk/numpy/lib/index_tricks.py	2008-04-11 06:53:49 UTC (rev 5021)
@@ -7,7 +7,8 @@
 
 import sys
 import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType, array
+from numpy.core.numeric import asarray, ScalarType, array, dtype
+from numpy.core.numerictypes import find_common_type
 import math
 
 import function_base
@@ -225,7 +226,8 @@
             key = (key,)
         objs = []
         scalars = []
-        final_dtypedescr = None
+        arraytypes = []
+        scalartypes = []
         for k in range(len(key)):
             scalar = False
             if type(key[k]) is slice:
@@ -272,6 +274,7 @@
                 newobj = array(key[k],ndmin=ndmin)
                 scalars.append(k)
                 scalar = True
+                scalartypes.append(newobj.dtype)
             else:
                 newobj = key[k]
                 if ndmin > 1:
@@ -289,14 +292,15 @@
                         newobj = newobj.transpose(axes)
                     del tempobj
             objs.append(newobj)
-            if isinstance(newobj, _nx.ndarray) and not scalar:
-                if final_dtypedescr is None:
-                    final_dtypedescr = newobj.dtype
-                elif newobj.dtype > final_dtypedescr:
-                    final_dtypedescr = newobj.dtype
-        if final_dtypedescr is not None:
+            if not scalar and isinstance(newobj, _nx.ndarray):
+                arraytypes.append(newobj.dtype)
+                
+        #  Ensure that scalars won't up-cast unless warranted
+        final_dtype = find_common_type(arraytypes, scalartypes)
+        if final_dtype is not None:
             for k in scalars:
-                objs[k] = objs[k].astype(final_dtypedescr)
+                objs[k] = objs[k].astype(final_dtype)
+
         res = _nx.concatenate(tuple(objs),axis=self.axis)
         return self._retval(res)
 

Modified: trunk/numpy/lib/tests/test_index_tricks.py
===================================================================
--- trunk/numpy/lib/tests/test_index_tricks.py	2008-04-11 06:34:20 UTC (rev 5020)
+++ trunk/numpy/lib/tests/test_index_tricks.py	2008-04-11 06:53:49 UTC (rev 5021)
@@ -35,6 +35,10 @@
         c = r_[b,0,0,b]
         assert_array_equal(c,[1,1,1,1,1,0,0,1,1,1,1,1])
 
+    def check_mixed_type(self):  # ticket #728: a float scalar mixed with an integer slice must give f8
+        g = r_[10.1, 1:10]
+        assert(g.dtype == 'f8')
+
     def check_2d(self):
         b = rand(5,5)
         c = rand(5,5)



More information about the Numpy-svn mailing list