[Scipy-svn] r2773 - trunk/Lib/sandbox/maskedarray

scipy-svn@scip... scipy-svn@scip...
Tue Feb 27 15:41:23 CST 2007


Author: pierregm
Date: 2007-02-27 15:41:19 -0600 (Tue, 27 Feb 2007)
New Revision: 2773

Modified:
   trunk/Lib/sandbox/maskedarray/__init__.py
   trunk/Lib/sandbox/maskedarray/core.py
   trunk/Lib/sandbox/maskedarray/extras.py
   trunk/Lib/sandbox/maskedarray/mrecords.py
   trunk/Lib/sandbox/maskedarray/testutils.py
Log:
testutils : fixed a problem in assert_equal (a bad indentation prevented assert_array_equal from being called)
core      : masked_where now returns a subclass of MaskedArray when needed

Modified: trunk/Lib/sandbox/maskedarray/__init__.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/__init__.py	2007-02-27 15:09:20 UTC (rev 2772)
+++ trunk/Lib/sandbox/maskedarray/__init__.py	2007-02-27 21:41:19 UTC (rev 2773)
@@ -12,11 +12,9 @@
 __date__     = '$Date$'
 
 import core
-#reload(core)
 from core import *
 
 import extras
-#reload(extras)
 from extras import *
 
 

Modified: trunk/Lib/sandbox/maskedarray/core.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/core.py	2007-02-27 15:09:20 UTC (rev 2772)
+++ trunk/Lib/sandbox/maskedarray/core.py	2007-02-27 21:41:19 UTC (rev 2773)
@@ -372,6 +372,8 @@
             return masked
         d1 = filled(a, self.fillx)
         d2 = filled(b, self.filly)
+# CHECK : Do we really need to fill the arguments ? Pro'ly not        
+#        result = self.f(a, b, *args, **kwargs).view(get_masked_subclass(a,b))
         result = self.f(d1, d2, *args, **kwargs).view(get_masked_subclass(a,b))
         if result.ndim > 0:
             result._mask = m
@@ -643,7 +645,7 @@
 #####--------------------------------------------------------------------------
 #--- --- Masking functions ---
 #####--------------------------------------------------------------------------
-def masked_where(condition, x, copy=True):
+def masked_where(condition, a, copy=True):
     """Returns `x` as an array masked where `condition` is true.
 Masked values of `x` or `condition` are kept.
 
@@ -652,12 +654,16 @@
     - `x` (ndarray) : Array to mask.
     - `copy` (boolean, *[False]*) : Returns a copy of `m` if true.
     """
-    cm = filled(condition,1)
-    if isinstance(x,MaskedArray):
-        m = mask_or(x._mask, cm)
-        return x.__class__(x._data, mask=m, copy=copy)
+    cond = filled(condition,1)
+    a = numeric.array(a, copy=copy, subok=True)
+    if hasattr(a, '_mask'):
+        cond = mask_or(cond, a._mask)
+        cls = type(a)
     else:
-        return MaskedArray(fromnumeric.asarray(x), copy=copy, mask=cm)
+        cls = MaskedArray
+    result = a.view(cls)
+    result._mask = cond
+    return result
 
 def masked_greater(x, value, copy=1):
     "Shortcut to `masked_where`, with ``condition = (x > value)``."

Modified: trunk/Lib/sandbox/maskedarray/extras.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/extras.py	2007-02-27 15:09:20 UTC (rev 2772)
+++ trunk/Lib/sandbox/maskedarray/extras.py	2007-02-27 21:41:19 UTC (rev 2773)
@@ -25,7 +25,6 @@
 from itertools import groupby
 
 import core
-#reload(core)
 from core import *
 
 import numpy

Modified: trunk/Lib/sandbox/maskedarray/mrecords.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/mrecords.py	2007-02-27 15:09:20 UTC (rev 2772)
+++ trunk/Lib/sandbox/maskedarray/mrecords.py	2007-02-27 21:41:19 UTC (rev 2773)
@@ -28,7 +28,6 @@
 _typestr = ntypes._typestr
 
 import maskedarray as MA
-#reload(MA)
 from maskedarray import masked, nomask, mask_or, filled, getmask, getmaskarray, \
     masked_array, make_mask
 from maskedarray import MaskedArray

Modified: trunk/Lib/sandbox/maskedarray/testutils.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/testutils.py	2007-02-27 15:09:20 UTC (rev 2772)
+++ trunk/Lib/sandbox/maskedarray/testutils.py	2007-02-27 21:41:19 UTC (rev 2773)
@@ -73,7 +73,7 @@
         return _assert_equal_on_sequences(actual.tolist(), 
                                           desired.tolist(), 
                                           err_msg='')
-        return assert_array_equal(actual, desired, err_msg)
+    return assert_array_equal(actual, desired, err_msg)
 #.............................
 def fail_if_equal(actual,desired,err_msg='',):
     """Raises an assertion error if two items are equal.



More information about the Scipy-svn mailing list