# [Scipy-svn] r3194 - in trunk/Lib/stats: . tests

scipy-svn@scip... scipy-svn@scip...
Thu Jul 26 04:57:33 CDT 2007

```
Author: cdavid
Date: 2007-07-26 04:56:18 -0500 (Thu, 26 Jul 2007)
New Revision: 3194

Modified:
trunk/Lib/stats/stats.py
trunk/Lib/stats/tests/test_stats.py
Log:
Correct nanstd and nanmedian + Add testsuite for nan related statistics. Should close #337.

Modified: trunk/Lib/stats/stats.py
===================================================================
--- trunk/Lib/stats/stats.py	2007-07-26 07:27:18 UTC (rev 3193)
+++ trunk/Lib/stats/stats.py	2007-07-26 09:56:18 UTC (rev 3194)
@@ -258,22 +258,36 @@
x, axis = _chk_asarray(x,axis)
x = x.copy()
Norig = x.shape[axis]
-    n = Norig - np.sum(np.isnan(x),axis)*1.0
-    factor = n/Norig

-    x[np.isnan(x)] = 0
-    m1 = np.mean(x,axis)
-    m1c = m1/factor
-    m2 = np.mean((x-m1c)**2.0,axis)
+    Nnan = np.sum(np.isnan(x),axis)*1.0
+    n = Norig - Nnan
+
+    x[np.isnan(x)] = 0.
+    m1 = np.sum(x,axis)/n
+
+    # Kludge to subtract m1 from the correct axis
+    if axis!=0:
+        shape = np.arange(x.ndim).tolist()
+        shape.remove(axis)
+        shape.insert(0,axis)
+        x = x.transpose(tuple(shape))
+        d = (x-m1)**2.0
+        shape = tuple(array(shape).argsort())
+        d = d.transpose(shape)
+    else:
+        d = (x-m1)**2.0
+    m2 = np.sum(d,axis)-(m1*m1)*Nnan
if bias:
-        m2c = m2/factor
+        m2c = m2 / n
else:
-        m2c = m2*Norig/(n-1.0)
-    return m2c
+        m2c = m2 / (n - 1.)
+    return np.sqrt(m2c)

def _nanmedian(arr1d):  # This only works on 1d arrays
cond = 1-np.isnan(arr1d)
x = np.sort(np.compress(cond,arr1d,axis=-1))
+    if x.size == 0:
+        return np.nan
return median(x)

def nanmedian(x, axis=0):

Modified: trunk/Lib/stats/tests/test_stats.py
===================================================================
--- trunk/Lib/stats/tests/test_stats.py	2007-07-26 07:27:18 UTC (rev 3193)
+++ trunk/Lib/stats/tests/test_stats.py	2007-07-26 09:56:18 UTC (rev 3194)
@@ -184,6 +184,64 @@
y = scipy.stats.std(ROUND)
assert_approx_equal(y, 2.738612788)

+class test_nanfunc(NumpyTestCase):
+    def __init__(self, *args, **kw):
+        NumpyTestCase.__init__(self, *args, **kw)
+        self.X = X.copy()
+
+        self.Xall = X.copy()
+        self.Xall[:] = numpy.nan
+
+        self.Xsome = X.copy()
+        self.Xsomet = X.copy()
+        self.Xsome[0] = numpy.nan
+        self.Xsomet = self.Xsomet[1:]
+
+    def check_nanmean_none(self):
+        """Check nanmean when no values are nan."""
+        m = stats.stats.nanmean(X)
+        assert_approx_equal(m, X[4])
+
+    def check_nanmean_some(self):
+        """Check nanmean when some values only are nan."""
+        m = stats.stats.nanmean(self.Xsome)
+        assert_approx_equal(m, 5.5)
+
+    def check_nanmean_all(self):
+        """Check nanmean when all values are nan."""
+        m = stats.stats.nanmean(self.Xall)
+        assert numpy.isnan(m)
+
+    def check_nanstd_none(self):
+        """Check nanstd when no values are nan."""
+        s = stats.stats.nanstd(self.X)
+        assert_approx_equal(s, stats.stats.std(self.X))
+
+    def check_nanstd_some(self):
+        """Check nanstd when some values only are nan."""
+        s = stats.stats.nanstd(self.Xsome)
+        assert_approx_equal(s, stats.stats.std(self.Xsomet))
+
+    def check_nanstd_all(self):
+        """Check nanstd when all values are nan."""
+        s = stats.stats.nanstd(self.Xall)
+        assert numpy.isnan(s)
+
+    def check_nanmedian_none(self):
+        """Check nanmedian when no values are nan."""
+        m = stats.stats.nanmedian(self.X)
+        assert_approx_equal(m, stats.stats.median(self.X))
+
+    def check_nanmedian_some(self):
+        """Check nanmedian when some values only are nan."""
+        m = stats.stats.nanmedian(self.Xsome)
+        assert_approx_equal(m, stats.stats.median(self.Xsomet))
+
+    def check_nanmedian_all(self):
+        """Check nanmedian when all values are nan."""
+        m = stats.stats.nanmedian(self.Xall)
+        assert numpy.isnan(m)
+
class test_corr(NumpyTestCase):
""" W.II.D. Compute a correlation matrix on all the variables.

```