[Scipy-svn] r4700 - in trunk/scipy/stats: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Sep 8 03:05:25 CDT 2008


Author: pierregm
Date: 2008-09-08 03:05:22 -0500 (Mon, 08 Sep 2008)
New Revision: 4700

Modified:
   trunk/scipy/stats/mstats.py
   trunk/scipy/stats/tests/test_mstats.py
Log:
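* force compatibility between mstats.mode and stats.mode: mstats.mode now returns a (modes, counts) tuple, keeping the reduced axis with length 1 when an axis is given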

Modified: trunk/scipy/stats/mstats.py
===================================================================
--- trunk/scipy/stats/mstats.py	2008-09-08 07:24:31 UTC (rev 4699)
+++ trunk/scipy/stats/mstats.py	2008-09-08 08:05:22 UTC (rev 4700)
@@ -256,16 +256,25 @@
     def _mode1D(a):
         (rep,cnt) = find_repeats(a)
         if not cnt.ndim:
-            return (0,0)
+            return (0, 0)
         elif cnt.size:
             return (rep[cnt.argmax()], cnt.max())
         return (a[0], 1)
     #
     if axis is None:
         output = _mode1D(ma.ravel(a))
+        output = (ma.array(output[0]), ma.array(output[1]))
     else:
         output = ma.apply_along_axis(_mode1D, axis, a)
-    return tuple(output)
+        newshape = list(a.shape)
+        newshape[axis] = 1
+        slices = [slice(None)] * output.ndim
+        slices[axis] = 0
+        modes = output[tuple(slices)].reshape(newshape)
+        slices[axis] = 1
+        counts = output[tuple(slices)].reshape(newshape)
+        output = (modes, counts)
+    return output
 mode.__doc__ = stats.mode.__doc__
 
 

Modified: trunk/scipy/stats/tests/test_mstats.py
===================================================================
--- trunk/scipy/stats/tests/test_mstats.py	2008-09-08 07:24:31 UTC (rev 4699)
+++ trunk/scipy/stats/tests/test_mstats.py	2008-09-08 08:05:22 UTC (rev 4700)
@@ -348,10 +348,10 @@
         assert_equal(mstats.mode(ma1, axis=None), (0,3))
         assert_equal(mstats.mode(a2, axis=None), (3,4))
         assert_equal(mstats.mode(ma2, axis=None), (0,3))
-        assert_equal(mstats.mode(a2, axis=0), [[0,0,0,1,1],[1,1,1,1,1]])
-        assert_equal(mstats.mode(ma2, axis=0), [[0,0,0,1,1],[1,1,1,1,1]])
-        assert_equal(mstats.mode(a2, axis=-1), [[0,3],[3,3],[3,1]])
-        assert_equal(mstats.mode(ma2, axis=-1), [[0,3],[1,1],[0,0]])
+        assert_equal(mstats.mode(a2, axis=0), ([[0,0,0,1,1]],[[1,1,1,1,1]]))
+        assert_equal(mstats.mode(ma2, axis=0), ([[0,0,0,1,1]],[[1,1,1,1,1]]))
+        assert_equal(mstats.mode(a2, axis=-1), ([[0],[3],[3]], [[3],[3],[1]]))
+        assert_equal(mstats.mode(ma2, axis=-1), ([[0],[1],[0]], [[3],[1],[0]]))
 
 
 class TestPercentile(TestCase):



More information about the Scipy-svn mailing list