[Scipy-svn] r6836 - trunk/scipy/ndimage

scipy-svn@scip... scipy-svn@scip...
Thu Oct 14 10:36:54 CDT 2010


Author: stefan
Date: 2010-10-14 10:36:54 -0500 (Thu, 14 Oct 2010)
New Revision: 6836

Modified:
   trunk/scipy/ndimage/measurements.py
Log:
ENH: ndimage.measurements: refactor _stats.

Modified: trunk/scipy/ndimage/measurements.py
===================================================================
--- trunk/scipy/ndimage/measurements.py	2010-10-13 21:44:46 UTC (rev 6835)
+++ trunk/scipy/ndimage/measurements.py	2010-10-14 15:36:54 UTC (rev 6836)
@@ -274,12 +274,13 @@
 
     return output
 
-def _stats(input, labels = None, index = None, do_sum2=False):
-    '''returns count, sum, and optionally sum^2 by label'''
+def _stats(input, labels=None, index=None, centered=False):
+    '''returns count, sum, and optionally (sum - centre)^2 by label'''
 
     def single_group(vals):
-        if do_sum2:
-            return vals.size, vals.sum(), (vals * vals.conjugate()).sum()
+        if centered:
+            vals_c = vals# - vals.mean()
+            return vals.size, vals.sum(), (vals_c * vals_c.conjugate()).sum()
         else:
             return vals.size, vals.sum()
 
@@ -295,16 +296,25 @@
     if numpy.isscalar(index):
         return single_group(input[labels == index])
 
+    counts = numpy.bincount(labels.ravel())
+    sums = numpy.bincount(labels.ravel(), weights=input.ravel())
+
+    def _sum_centered(labels):
+        means = sums / counts
+        centered_input = input# - means[labels]
+        return numpy.bincount(labels,
+                              weights=(centered_input * \
+                                       centered_input.conjugate()).ravel())
+
     # remap labels to unique integers if necessary, or if the largest
     # label is larger than the number of values.
+
     if ((not numpy.issubdtype(labels.dtype, numpy.int)) or
         (labels.min() < 0) or (labels.max() > labels.size)):
         unique_labels, new_labels = numpy.unique1d(labels, return_inverse=True)
 
-        counts = numpy.bincount(new_labels)
-        sums = numpy.bincount(new_labels, weights=input.ravel())
-        if do_sum2:
-            sums2 = numpy.bincount(new_labels, weights=(input * input.conjugate()).ravel())
+        if centered:
+            sums_c, sums, counts = _sum_centered(new_labels)
 
         idxs = numpy.searchsorted(unique_labels, index)
         # make all of idxs valid
@@ -313,25 +323,25 @@
     else:
         # labels are an integer type, and there aren't too many, so
         # call bincount directly.
-        counts = numpy.bincount(labels.ravel())
-        sums = numpy.bincount(labels.ravel(), weights=input.ravel())
-        if do_sum2:
-            sums2 = numpy.bincount(labels.ravel(), weights=(input * input.conjugate()).ravel())
+        if centered:
+            sums_c = _sum_centered(labels.ravel())
 
         # make sure all index values are valid
         idxs = numpy.asanyarray(index, numpy.int).copy()
         found = (idxs >= 0) & (idxs < counts.size)
-        idxs[~ found] = 0
+        idxs[~found] = 0
 
     counts = counts[idxs]
-    counts[~ found] = 0
+    counts[~found] = 0
     sums = sums[idxs]
-    sums[~ found] = 0
-    if not do_sum2:
+    sums[~found] = 0
+
+    if not centered:
         return (counts, sums)
-    sums2 = sums2[idxs]
-    sums2[~ found] = 0
-    return (counts, sums, sums2)
+    else:
+        sums_c = sums_c[idxs]
+        sums_c[~found] = 0
+        return (counts, sums, sums_c)
 
 
 def sum(input, labels = None, index = None):
@@ -398,9 +408,9 @@
     none, all values where label is greater than zero are used.
     """
 
-    count, sum, sum2 = _stats(input, labels, index, do_sum2=True)
+    count, sum, sum_c = _stats(input, labels, index, centered=True)
     mean = sum / numpy.asanyarray(count).astype(numpy.float)
-    mean2 = sum2 / numpy.asanyarray(count).astype(numpy.float)
+    mean2 = sum_c / numpy.asanyarray(count).astype(numpy.float)
 
     return mean2 - (mean * mean.conjugate())
 



More information about the Scipy-svn mailing list