[Scipy-svn] r6983 - in trunk/scipy/ndimage: . tests

scipy-svn@scip... scipy-svn@scip...
Wed Dec 1 00:07:30 CST 2010


Author: warren.weckesser
Date: 2010-12-01 00:07:29 -0600 (Wed, 01 Dec 2010)
New Revision: 6983

Modified:
   trunk/scipy/ndimage/measurements.py
   trunk/scipy/ndimage/tests/test_measurements.py
Log:
BUG: ndimage: correctly deal with 'labels' data types that are compatible with bincount (tickets #1254 and #1242).  Use numpy.unique instead of numpy.unique1d.

Modified: trunk/scipy/ndimage/measurements.py
===================================================================
--- trunk/scipy/ndimage/measurements.py	2010-11-29 23:41:08 UTC (rev 6982)
+++ trunk/scipy/ndimage/measurements.py	2010-12-01 06:07:29 UTC (rev 6983)
@@ -374,6 +374,13 @@
 
     return output
 
+def _safely_castable_to_int(dt):
+    """Test whether the numpy data type `dt` can be safely cast to an int."""
+    int_size = np.dtype(int).itemsize  # itemsize of numpy's default int
+    safe = ((np.issubdtype(dt, np.signedinteger) and dt.itemsize <= int_size) or
+            (np.issubdtype(dt, np.unsignedinteger) and dt.itemsize < int_size))
+    return safe
+
 def _stats(input, labels=None, index=None, centered=False):
     '''returns count, sum, and optionally (sum - centre)^2 by label'''
 
@@ -396,36 +403,35 @@
     if numpy.isscalar(index):
         return single_group(input[labels == index])
 
-    counts = numpy.bincount(labels.ravel())
-    sums = numpy.bincount(labels.ravel(), weights=input.ravel())
-
     def _sum_centered(labels):
         means = sums / counts
         centered_input = input - means[labels]
-        return numpy.bincount(labels,
+        bc = numpy.bincount(labels.ravel(),  # bincount needs 1-d labels
                                weights=(centered_input * \
                                         centered_input.conjugate()).ravel())
+        return bc
 
-    # remap labels to unique integers if necessary, or if the largest
+    # Remap labels to unique integers if necessary, or if the largest
     # label is larger than the number of values.
 
-    if not numpy.issubdtype(labels.dtype, (numpy.int, np.unsignedinteger)) or \
-           (labels.min() < 0) or (labels.max() > labels.size):
-        unique_labels, new_labels = numpy.unique1d(labels, return_inverse=True)
-
+    if (not _safely_castable_to_int(labels.dtype) or
+            labels.min() < 0 or labels.max() > labels.size):
+        unique_labels, new_labels = numpy.unique(labels, return_inverse=True)
+        counts = numpy.bincount(new_labels.ravel())
+        sums = numpy.bincount(new_labels.ravel(), weights=input.ravel())
         if centered:
-            sums_c, sums, counts = _sum_centered(new_labels)
-
+            sums_c = _sum_centered(new_labels.reshape(labels.shape))
         idxs = numpy.searchsorted(unique_labels, index)
         # make all of idxs valid
         idxs[idxs >= unique_labels.size] = 0
         found = (unique_labels[idxs] == index)
     else:
-        # labels are an integer type, and there aren't too many, so
-        # call bincount directly.
+        # labels are an integer type allowed by bincount, and there aren't too
+        # many, so call bincount directly.
+        counts = numpy.bincount(labels.ravel())
+        sums = numpy.bincount(labels.ravel(), weights=input.ravel())
         if centered:
-            sums_c = _sum_centered(labels.ravel())
+            sums_c = _sum_centered(labels)  # keep n-d shape to index means
-
         # make sure all index values are valid
         idxs = numpy.asanyarray(index, numpy.int).copy()
         found = (idxs >= 0) & (idxs < counts.size)
@@ -645,6 +651,7 @@
 def _select(input, labels = None, index = None, find_min=False, find_max=False, find_min_positions=False, find_max_positions=False):
     '''returns min, max, or both, plus positions if requested'''
 
+    input = numpy.asanyarray(input)
 
     find_positions = find_min_positions or find_max_positions
     positions = None
@@ -691,10 +698,10 @@
 
     # remap labels to unique integers if necessary, or if the largest
     # label is larger than the number of values.
-    if ((not numpy.issubdtype(labels.dtype, numpy.int)) or
-        (labels.min() < 0) or (labels.max() > labels.size)):
+    if (not _safely_castable_to_int(labels.dtype) or
+            labels.min() < 0 or labels.max() > labels.size):
         # remap labels, and indexes
-        unique_labels, labels = numpy.unique1d(labels, return_inverse=True)
+        unique_labels, labels = numpy.unique(labels, return_inverse=True)
         idxs = numpy.searchsorted(unique_labels, index)
 
         # make all of idxs valid

Modified: trunk/scipy/ndimage/tests/test_measurements.py
===================================================================
--- trunk/scipy/ndimage/tests/test_measurements.py	2010-11-29 23:41:08 UTC (rev 6982)
+++ trunk/scipy/ndimage/tests/test_measurements.py	2010-12-01 06:07:29 UTC (rev 6983)
@@ -1,6 +1,6 @@
 from numpy.testing import assert_, assert_array_almost_equal, assert_equal, \
-                          assert_almost_equal, \
-                          run_module_suite
+                          assert_almost_equal, assert_array_equal, \
+                          run_module_suite, TestCase
 import numpy as np
 
 import scipy.ndimage as ndimage
@@ -10,6 +10,90 @@
          np.int64, np.uint64,
          np.float32, np.float64]
 
+
+class Test_measurements_stats(TestCase):
+    """ndimage.measurements._stats() is a utility function used by other functions."""
+
+    def test_a(self):
+        x = [0,1,2,6]
+        labels = [0,0,1,1]
+        index = [0,1]
+        counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
+        assert_array_equal(counts, [2, 2])
+        assert_array_equal(sums, [1.0, 8.0])
+
+    def test_b(self):
+        # Same data as test_a, but different labels.  The label 9 exceeds the
+        # length of 'labels', so this test will follow a different code path.
+        x = [0,1,2,6]
+        labels = [0,0,9,9]
+        index = [0,9]
+        counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
+        assert_array_equal(counts, [2, 2])
+        assert_array_equal(sums, [1.0, 8.0])
+
+    def test_a_centered(self):
+        x = [0,1,2,6]
+        labels = [0,0,1,1]
+        index = [0,1]
+        counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+                                    index=index, centered=True)
+        assert_array_equal(counts, [2, 2])
+        assert_array_equal(sums, [1.0, 8.0])
+        assert_array_equal(centers, [0.5, 8.0])
+
+    def test_b_centered(self):
+        x = [0,1,2,6]
+        labels = [0,0,9,9]
+        index = [0,9]
+        counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+                                    index=index, centered=True)
+        assert_array_equal(counts, [2, 2])
+        assert_array_equal(sums, [1.0, 8.0])
+        assert_array_equal(centers, [0.5, 8.0])
+
+    def test_nonint_labels(self):
+        x = [0,1,2,6]
+        labels = [0.0, 0.0, 9.0, 9.0]
+        index = [0.0, 9.0]
+        counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+                                    index=index, centered=True)
+        assert_array_equal(counts, [2, 2])
+        assert_array_equal(sums, [1.0, 8.0])
+        assert_array_equal(centers, [0.5, 8.0])
+
+
+class Test_measurements_select(TestCase):
+    """ndimage.measurements._select() is a utility function used by other functions."""
+
+    def test_basic(self):
+        x = [0,1,6,2]
+        cases = [
+            ([0,0,1,1], [0,1]),                 # "Small" integer labels
+            ([0,0,9,9], [0,9]),                 # A label larger than len(labels)
+            ([0.0,0.0,7.0,7.0], [0.0, 7.0]),    # Non-integer labels
+        ]
+        for labels, index in cases:
+            result = ndimage.measurements._select(x, labels=labels, index=index)
+            assert_equal(len(result), 0)
+            result = ndimage.measurements._select(x, labels=labels, index=index, find_max=True)
+            assert_equal(len(result), 1)
+            assert_array_equal(result[0], [1, 6])
+            result = ndimage.measurements._select(x, labels=labels, index=index, find_min=True)
+            assert_equal(len(result), 1)
+            assert_array_equal(result[0], [0, 2])
+            result = ndimage.measurements._select(x, labels=labels, index=index,
+                                find_min=True, find_min_positions=True)
+            assert_equal(len(result), 2)
+            assert_array_equal(result[0], [0, 2])
+            assert_array_equal(result[1], [0, 3])
+            result = ndimage.measurements._select(x, labels=labels, index=index,
+                                find_max=True, find_max_positions=True)
+            assert_equal(len(result), 2)
+            assert_array_equal(result[0], [1, 6])
+            assert_array_equal(result[1], [1, 2])
+
+
 def test_label01():
     "label 1"
     data = np.ones([])



More information about the Scipy-svn mailing list