[Numpy-svn] r5423 - in branches/1.1.x/numpy/ma: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jul 15 16:49:50 CDT 2008


Author: pierregm
Date: 2008-07-15 16:49:36 -0500 (Tue, 15 Jul 2008)
New Revision: 5423

Modified:
   branches/1.1.x/numpy/ma/core.py
   branches/1.1.x/numpy/ma/mrecords.py
   branches/1.1.x/numpy/ma/tests/test_core.py
   branches/1.1.x/numpy/ma/tests/test_mrecords.py
Log:
* Improved support for flexible dtypes (with nesting and shaped fields)

Modified: branches/1.1.x/numpy/ma/core.py
===================================================================
--- branches/1.1.x/numpy/ma/core.py	2008-07-15 17:42:50 UTC (rev 5422)
+++ branches/1.1.x/numpy/ma/core.py	2008-07-15 21:49:36 UTC (rev 5423)
@@ -125,7 +125,10 @@
     if hasattr(obj,'dtype'):
         defval = default_filler[obj.dtype.kind]
     elif isinstance(obj, np.dtype):
-        defval = default_filler[obj.kind]
+        if obj.subdtype:
+            defval = default_filler[obj.subdtype[0].kind]
+        else:
+            defval = default_filler[obj.kind]
     elif isinstance(obj, float):
         defval = default_filler['f']
     elif isinstance(obj, int) or isinstance(obj, long):
@@ -184,19 +187,28 @@
 
 
 def _check_fill_value(fill_value, dtype):
-    descr = np.dtype(dtype).descr
+    ndtype = np.dtype(dtype)
+    fields = ndtype.fields
     if fill_value is None:
-        if len(descr) > 1:
-            fill_value = [default_fill_value(np.dtype(d[1]))
-                          for d in descr]
+        if fields:
+            fill_value = [default_fill_value(fields[n][0])
+                          for n in ndtype.names]
         else:
-            fill_value = default_fill_value(dtype)
+            fill_value = default_fill_value(ndtype)
     else:
         fill_value = np.array(fill_value).tolist()
-        fval = np.resize(fill_value, len(descr))
-        if len(descr) > 1:
-            fill_value = [np.asarray(f).astype(d[1]).item()
-                          for (f,d) in zip(fval, descr)]
+        fval = np.resize(fill_value, len(ndtype.descr))
+        if fields:
+            fill_value = []
+            for (f, n) in zip(fval, ndtype.names):
+                current = fields[n][0]
+                if current.subdtype:
+                    fill_value.append(np.asarray(f).astype(current.subdtype[0]))
+                else:
+                    fill_value.append(np.asarray(f).astype(current))
+#            
+#            fill_value = [np.asarray(f).astype(fields[n][0]).item()
+#                          for (f, n) in zip(fval, ndtype.names)]
         else:
             fill_value = np.array(fval, copy=False, dtype=dtype).item()
     return fill_value
@@ -1147,7 +1159,9 @@
         Value used to fill in the masked values when necessary. If
         None, a default based on the datatype is used.
     keep_mask : {True, boolean}
-        Whether to combine mask with the mask of the input data,
+        Whether to combine m
+        x = mrecarray(1, formats="(2,2)f8")
+        assert_equal(x.fill_value, ma.default_fill_value(np.dtype(float)))ask with the mask of the input data,
         if any (True), or to use only mask for the output (False).
     hard_mask : {False, boolean}
         Whether to use a hard mask or not. With a hard mask,
@@ -3750,3 +3764,5 @@
 zeros = _convert2ma('zeros')
 
 ###############################################################################
+if __name__ == '__main__':
+    x = array((1,), dtype=[('f0', '<f8', (2, 2))])

Modified: branches/1.1.x/numpy/ma/mrecords.py
===================================================================
--- branches/1.1.x/numpy/ma/mrecords.py	2008-07-15 17:42:50 UTC (rev 5422)
+++ branches/1.1.x/numpy/ma/mrecords.py	2008-07-15 21:49:36 UTC (rev 5423)
@@ -88,11 +88,13 @@
     return np.dtype(ndescr)
 
 
-def _get_fieldmask(self):
-    mdescr = [(n,'|b1') for n in self.dtype.names]
-    fdmask = np.empty(self.shape, dtype=mdescr)
-    fdmask.flat = tuple([False]*len(mdescr))
-    return fdmask
+def _make_mask_dtype(ndtype):
+    mdescr = []
+    for descr in ndtype.descr:
+        current = list(descr)
+        current[1] = '|b1'
+        mdescr.append(tuple(current))
+    return mdescr
 
 
 class MaskedRecords(MaskedArray, object):
@@ -119,10 +121,11 @@
                 **options):
         #
         self = recarray.__new__(cls, shape, dtype=dtype, buf=buf, offset=offset,
-                                strides=strides, formats=formats,
-                                byteorder=byteorder, aligned=aligned,)
+                                strides=strides, formats=formats, names=names,
+                                titles=titles, byteorder=byteorder,
+                                aligned=aligned,)
         #
-        mdtype = [(k,'|b1') for (k,_) in self.dtype.descr]
+        mdtype = _make_mask_dtype(self.dtype)
         if mask is nomask or not np.size(mask):
             if not keep_mask:
                 self._fieldmask = tuple([False]*len(mdtype))
@@ -155,7 +158,7 @@
         # Make sure we have a _fieldmask by default ..
         _fieldmask = getattr(obj, '_fieldmask', None)
         if _fieldmask is None:
-            mdescr = [(n,'|b1') for (n,_) in self.dtype.descr]
+            mdescr = _make_mask_dtype(ndarray.__getattribute__(self, 'dtype'))
             _mask = getattr(obj, '_mask', nomask)
             if _mask is nomask:
                 _fieldmask = np.empty(self.shape, dtype=mdescr).view(recarray)
@@ -187,8 +190,8 @@
     #......................................................
     def __setmask__(self, mask):
         "Sets the mask and update the fieldmask."
-        names = self.dtype.names
         fmask = self.__dict__['_fieldmask']
+        names = fmask.dtype.names
         #
         if isinstance(mask,ndarray) and mask.dtype.names == names:
             for n in names:
@@ -202,7 +205,12 @@
                     fmask[n].__ior__(newmask)
             else:
                 for n in names:
-                    fmask[n].flat = newmask
+                    current = fmask[n]
+                    if current.shape == newmask.shape or newmask.size == 1:
+                        current.flat = newmask
+                    else:
+                        for (i,n) in enumerate(newmask):
+                            current[i] = n
         return
     _setmask = __setmask__
     #
@@ -211,10 +219,16 @@
     A record is masked when all the fields are masked.
 
         """
-        if self.size > 1:
-            return self._fieldmask.view((bool_, len(self.dtype))).all(1)
+        fieldmask = ndarray.__getattribute__(self, '_fieldmask')
+        if fieldmask.size > 1:
+            axis = 1
         else:
-            return self._fieldmask.view((bool_, len(self.dtype))).all()
+            axis=None
+        try:
+            return fieldmask.view((bool_, len(self.dtype))).all(axis)
+        except ValueError:
+            return np.all([[f[n].all() for n in fieldmask.dtype.names]
+                           for f in fieldmask], axis=axis)
     mask = _mask = property(fget=_getmask, fset=_setmask)
     #......................................................
     def get_fill_value(self):
@@ -224,7 +238,11 @@
         if self._fill_value is None:
             ddtype = self.dtype
             fillval = _check_fill_value(None, ddtype)
-            self._fill_value = np.array(tuple(fillval), dtype=ddtype)
+            # We can't use ddtype to reconstruct the array as we don't need...
+            # ... the shape of the fields
+            self._fill_value = np.array(tuple(fillval),
+                                        dtype=zip(ddtype.names, 
+                                                  (_[1] for _ in ddtype.descr)))
         return self._fill_value
 
     def set_fill_value(self, value=None):
@@ -421,7 +439,7 @@
 
         """
         _localdict = self.__dict__
-        d = self._data
+        d = ndarray.__getattribute__(self, '_data')
         fm = _localdict['_fieldmask']
         if not np.asarray(fm, dtype=bool_).any():
             return d
@@ -437,7 +455,7 @@
             result = np.asanyarray(value)
         else:
             result = d.copy()
-            for (n, v) in zip(d.dtype.names, value):
+            for (n, v) in zip(fm.dtype.names, value):
                 np.putmask(np.asarray(result[n]), np.asarray(fm[n]), v)
         return result
     #......................................................
@@ -776,3 +794,15 @@
     return newdata
 
 ###############################################################################
+if __name__ == '__main__':
+    from numpy.ma.testutils import assert_equal
+    
+    if 1:
+        ilist = [1,2,3,4,5]
+        flist = [1.1,2.2,3.3,4.4,5.5]
+        slist = ['one','two','three','four','five']
+        ddtype = [('a',int),('b',float),('c','|S8')]
+        mask = [0,1,0,0,1]
+        base = ma.array(zip(ilist,flist,slist), mask=mask, dtype=ddtype)
+        mbase = base.view(mrecarray)
+        mbase._mask = nomask
\ No newline at end of file

Modified: branches/1.1.x/numpy/ma/tests/test_core.py
===================================================================
--- branches/1.1.x/numpy/ma/tests/test_core.py	2008-07-15 17:42:50 UTC (rev 5422)
+++ branches/1.1.x/numpy/ma/tests/test_core.py	2008-07-15 21:49:36 UTC (rev 5423)
@@ -792,7 +792,7 @@
         series = data[[0,2,1]]
         assert_equal(series._fill_value, data._fill_value)
         #
-        mtype = [('f',float_),('s','|S3')]
+        mtype = [('f',float),('s','|S3')]
         x = array([(1,'a'),(2,'b'),(numpy.pi,'pi')], dtype=mtype)
         x.fill_value=999
         assert_equal(x.fill_value,[999.,'999'])

Modified: branches/1.1.x/numpy/ma/tests/test_mrecords.py
===================================================================
--- branches/1.1.x/numpy/ma/tests/test_mrecords.py	2008-07-15 17:42:50 UTC (rev 5422)
+++ branches/1.1.x/numpy/ma/tests/test_mrecords.py	2008-07-15 21:49:36 UTC (rev 5423)
@@ -270,6 +270,31 @@
         #
         assert_equal(mrec.tolist(),
                      [(1,1.1,None),(2,2.2,'two'),(None,None,'three')])
+    #
+    def test_withnames(self):
+        "Test the creation w/ format and names"
+        x = mrecarray(1, formats=float, names='base')
+        x[0]['base'] = 10
+        assert_equal(x['base'][0], 10)
+    #
+    def test_exotic_formats(self):
+        "Test that 'exotic' formats are processed properly"
+        easy = mrecarray(1, dtype=[('i',int), ('s','|S3'), ('f',float)])
+        easy[0] = masked
+        easy.filled(1)
+        assert_equal(easy.filled(1).item(), (1,'1',1.))
+        #
+        solo = mrecarray(1, dtype=[('f0', '<f8', (2, 2))])
+        solo[0] = masked
+        assert_equal(solo.filled(1).item(), 
+                     np.array((1,), dtype=solo.dtype).item())
+        #
+        mult = mrecarray(2, dtype= "i4, (2,3)float, float")
+        mult[0] = masked
+        mult[1] = (1, 1, 1)
+        mult.filled(0)
+        assert_equal(mult.filled(0),
+                     np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
 
 ################################################################################
 class TestMRecordsImport(NumpyTestCase):



More information about the Numpy-svn mailing list