[Numpy-svn] r6294 - in trunk/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Sun Jan 4 14:16:03 CST 2009


Author: pierregm
Date: 2009-01-04 14:16:00 -0600 (Sun, 04 Jan 2009)
New Revision: 6294

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* adapted default_fill_value for flexible datatypes
* fixed maximum_fill_value and minimum_fill_value for flexible datatypes

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2009-01-04 12:03:29 UTC (rev 6293)
+++ trunk/numpy/ma/core.py	2009-01-04 20:16:00 UTC (rev 6294)
@@ -152,7 +152,7 @@
 
     """
     if hasattr(obj,'dtype'):
-        defval = default_filler[obj.dtype.kind]
+        defval = _check_fill_value(None, obj.dtype)
     elif isinstance(obj, np.dtype):
         if obj.subdtype:
             defval = default_filler[obj.subdtype[0].kind]
@@ -170,6 +170,18 @@
         defval = default_filler['O']
     return defval
 
+
+def _recursive_extremum_fill_value(ndtype, extremum):
+    names = ndtype.names
+    if names:
+        deflist = []
+        for name in names:
+            fval = _recursive_extremum_fill_value(ndtype[name], extremum)
+            deflist.append(fval)
+        return tuple(deflist)
+    return extremum[ndtype]
+
+
 def minimum_fill_value(obj):
     """
     Calculate the default fill value suitable for taking the minimum of ``obj``.
@@ -177,11 +189,7 @@
     """
     errmsg = "Unsuitable type for calculating minimum."
     if hasattr(obj, 'dtype'):
-        objtype = obj.dtype
-        filler = min_filler[objtype]
-        if filler is None:
-            raise TypeError(errmsg)
-        return filler
+        return _recursive_extremum_fill_value(obj.dtype, min_filler)
     elif isinstance(obj, float):
         return min_filler[ntypes.typeDict['float_']]
     elif isinstance(obj, int):
@@ -193,6 +201,7 @@
     else:
         raise TypeError(errmsg)
 
+
 def maximum_fill_value(obj):
     """
     Calculate the default fill value suitable for taking the maximum of ``obj``.
@@ -200,11 +209,7 @@
     """
     errmsg = "Unsuitable type for calculating maximum."
     if hasattr(obj, 'dtype'):
-        objtype = obj.dtype
-        filler = max_filler[objtype]
-        if filler is None:
-            raise TypeError(errmsg)
-        return filler
+        return _recursive_extremum_fill_value(obj.dtype, max_filler)
     elif isinstance(obj, float):
         return max_filler[ntypes.typeDict['float_']]
     elif isinstance(obj, int):
@@ -257,7 +262,7 @@
         if fields:
             descr = ndtype.descr
             fill_value = np.array(_recursive_set_default_fill_value(descr),
-                                  dtype=ndtype)
+                                  dtype=ndtype,)
         else:
             fill_value = default_fill_value(ndtype)
     elif fields:

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2009-01-04 12:03:29 UTC (rev 6293)
+++ trunk/numpy/ma/tests/test_core.py	2009-01-04 20:16:00 UTC (rev 6294)
@@ -1074,6 +1074,29 @@
         control = np.array((0,0,0), dtype="int, float, float").astype(ndtype)
         assert_equal(_check_fill_value(0, ndtype), control)
 
+
+    def test_extremum_fill_value(self):
+        "Tests extremum fill values for flexible type."
+        a = array([(1, (2, 3)), (4, (5, 6))],
+                  dtype=[('A', int), ('B', [('BA', int), ('BB', int)])])
+        test = a.fill_value
+        assert_equal(test['A'], default_fill_value(a['A']))
+        assert_equal(test['B']['BA'], default_fill_value(a['B']['BA']))
+        assert_equal(test['B']['BB'], default_fill_value(a['B']['BB']))
+        #
+        test = minimum_fill_value(a)
+        assert_equal(test[0], minimum_fill_value(a['A']))
+        assert_equal(test[1][0], minimum_fill_value(a['B']['BA']))
+        assert_equal(test[1][1], minimum_fill_value(a['B']['BB']))
+        assert_equal(test[1], minimum_fill_value(a['B']))
+        #
+        test = maximum_fill_value(a)
+        assert_equal(test[0], maximum_fill_value(a['A']))
+        assert_equal(test[1][0], maximum_fill_value(a['B']['BA']))
+        assert_equal(test[1][1], maximum_fill_value(a['B']['BB']))
+        assert_equal(test[1], maximum_fill_value(a['B']))
+    
+
 #------------------------------------------------------------------------------
 
 class TestUfuncs(TestCase):
@@ -1820,6 +1843,28 @@
         assert_equal(am, an)
 
 
+    def test_sort_flexible(self):
+        "Test sort on flexible dtype."
+        a = array([(3, 3), (3, 2), (2, 2), (2, 1), (1, 0), (1, 1), (1, 2)],
+             mask=[(0, 0), (0, 1), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0)],
+            dtype=[('A', int), ('B', int)])
+        #
+        test = sort(a)
+        b = array([(1, 1), (1, 2), (2, 1), (2, 2), (3, 3), (3, 2), (1, 0)],
+             mask=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (1, 0)],
+            dtype=[('A', int), ('B', int)])
+        assert_equal(test, b)
+        assert_equal(test.mask, b.mask)
+        #
+        test = sort(a, endwith=False)
+        b = array([(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3),],
+             mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0),],
+            dtype=[('A', int), ('B', int)])
+        assert_equal(test, b)
+        assert_equal(test.mask, b.mask)
+        #
+
+
     def test_squeeze(self):
         "Check squeeze"
         data = masked_array([[1,2,3]])



More information about the Numpy-svn mailing list