[Scipy-svn] r4778 - in branches/spatial/scipy/spatial: . tests

scipy-svn@scip... scipy-svn@scip...
Sun Oct 5 03:00:06 CDT 2008

```Author: peridot
Date: 2008-10-05 02:59:51 -0500 (Sun, 05 Oct 2008)
New Revision: 4778

Modified:
branches/spatial/scipy/spatial/kdtree.py
branches/spatial/scipy/spatial/tests/test_kdtree.py
Log:

Modified: branches/spatial/scipy/spatial/kdtree.py
===================================================================
--- branches/spatial/scipy/spatial/kdtree.py	2008-10-05 05:51:47 UTC (rev 4777)
+++ branches/spatial/scipy/spatial/kdtree.py	2008-10-05 07:59:51 UTC (rev 4778)
@@ -132,12 +132,14 @@
class leafnode(node):
def __init__(self, idx):
self.idx = idx
+            self.children = len(idx)  # number of entries in idx stored at this leaf
class innernode(node):
def __init__(self, split_dim, split, less, greater):
self.split_dim = split_dim
self.split = split
self.less = less
self.greater = greater
+            self.children = less.children+greater.children  # combined count of both subtrees

def __build(self, idx, maxes, mins):
if len(idx)<=self.leafsize:
@@ -494,3 +496,79 @@
return results

+    def count_neighbors(self, other, r, p=2.):
+        """Count how many nearby pairs can be formed.
+
+        Count the number of pairs (x1,x2) that can be formed, with x1 drawn
+        from self and x2 drawn from other, and where distance(x1,x2,p)<=r.
+        This is the "two-point correlation" described in Gray and Moore 2000,
+        "N-body problems in statistical learning", and the code here is based
+        on their algorithm.
+
+        Parameters
+        ==========
+
+        other : KDTree
+            The tree that the second point of each pair is drawn from.
+
+        r : float or one-dimensional array of floats
+            The radius to produce a count for. Multiple radii are searched with a single
+            tree traversal.
+        p : float, 1<=p<=infinity
+            Which Minkowski p-norm to use
+
+        Returns
+        =======
+
+        result : integer or one-dimensional array of integers
+            The number of pairs. Note that this is internally stored in a numpy int,
+            and so may overflow if very large (two billion).
+
+        Raises
+        ======
+
+        ValueError
+            If r has more than one dimension.
+        """
+
+        def traverse(node1, rect1, node2, rect2, idx):
+            # node1/rect1 belong to self, node2/rect2 to other.  idx holds the
+            # positions in r (and result) that are still undecided for this
+            # pair of nodes.
+            # NOTE(review): min_r/max_r are presumably lower/upper bounds on
+            # the distance between any point of rect1 and any point of rect2
+            # -- confirm against Rectangle.
+            min_r = rect1.min_distance_rectangle(rect2,p)
+            max_r = rect1.max_distance_rectangle(rect2,p)
+            # Radii greater than max_r: all node1.children*node2.children
+            # pairs are added at once, without descending any further.
+            c_greater = r[idx]>max_r
+            result[idx[c_greater]] += node1.children*node2.children
+            # Only radii inside [min_r, max_r] stay undecided; smaller radii
+            # receive nothing from this node pair.
+            idx = idx[(min_r<=r[idx]) & (r[idx]<=max_r)]
+            if len(idx)==0:
+                return
+
+            if isinstance(node1,KDTree.leafnode):
+                if isinstance(node2,KDTree.leafnode):
+                    # Two leaves: compute every pairwise distance, sort them,
+                    # and let searchsorted(side='right') count how many are
+                    # <= each remaining radius.
+                    ds = distance(self.data[node1.idx][:,np.newaxis,:],
+                                  other.data[node2.idx][np.newaxis,:,:],
+                                  p).ravel()
+                    ds.sort()
+                    result[idx] += np.searchsorted(ds,r[idx],side='right')
+                else:
+                    # Only node2 can be split further.
+                    less, greater = rect2.split(node2.split_dim, node2.split)
+                    traverse(node1, rect1, node2.less, less, idx)
+                    traverse(node1, rect1, node2.greater, greater, idx)
+            else:
+                if isinstance(node2,KDTree.leafnode):
+                    # Only node1 can be split further.
+                    less, greater = rect1.split(node1.split_dim, node1.split)
+                    traverse(node1.less, less, node2, rect2, idx)
+                    traverse(node1.greater, greater, node2, rect2, idx)
+                else:
+                    # Both nodes are inner nodes: split both and visit all
+                    # four combinations of their halves.
+                    less1, greater1 = rect1.split(node1.split_dim, node1.split)
+                    less2, greater2 = rect2.split(node2.split_dim, node2.split)
+                    traverse(node1.less,less1,node2.less,less2,idx)
+                    traverse(node1.less,less1,node2.greater,greater2,idx)
+                    traverse(node1.greater,greater1,node2.less,less2,idx)
+                    traverse(node1.greater,greater1,node2.greater,greater2,idx)
+        R1 = Rectangle(self.maxes, self.mins)
+        R2 = Rectangle(other.maxes, other.mins)
+        if np.shape(r) == ():
+            # Scalar radius: wrap it in a one-element array so traverse can
+            # index it, then unwrap the single count on return.
+            r = np.array([r])
+            result = np.zeros(1,dtype=int)
+            traverse(self.tree, R1, other.tree, R2, np.arange(1))
+            return result[0]
+        elif len(np.shape(r))==1:
+            # One count per radius, all accumulated in a single traversal.
+            r = np.asarray(r)
+            n, = r.shape
+            result = np.zeros(n,dtype=int)
+            traverse(self.tree, R1, other.tree, R2, np.arange(n))
+            return result
+        else:
+            raise ValueError("r must be either a single value or a one-dimensional array of values")
+

Modified: branches/spatial/scipy/spatial/tests/test_kdtree.py
===================================================================
--- branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-05 05:51:47 UTC (rev 4777)
+++ branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-05 07:59:51 UTC (rev 4778)
@@ -311,3 +311,28 @@
x = np.random.randn(10,1,3)
y = np.random.randn(1,7,3)
assert_equal(distance(x,y).shape,(10,7))
+
+class test_count_neighbors:
+    """Check KDTree.count_neighbors against counts derived from query_ball_tree."""
+
+    def setUp(self):
+        # Fixture only: two independent random trees of n points in k
+        # dimensions.  The checks live in the test_* methods below so that
+        # the test runner actually collects and executes them.
+        n = 100
+        k = 4
+        self.T1 = KDTree(np.random.randn(n,k),leafsize=2)
+        self.T2 = KDTree(np.random.randn(n,k),leafsize=2)
+
+    def test_one_radius(self):
+        # Reference value: the summed lengths of the per-point result lists
+        # returned by query_ball_tree for the same radius.
+        r = 0.2
+        assert_equal(self.T1.count_neighbors(self.T2, r),
+                np.sum([len(l) for l in self.T1.query_ball_tree(self.T2,r)]))
+
+    def test_large_radius(self):
+        # A radius far larger than the spread of the random samples, so
+        # presumably every pair of points is within range.
+        r = 1000
+        assert_equal(self.T1.count_neighbors(self.T2, r),
+                np.sum([len(l) for l in self.T1.query_ball_tree(self.T2,r)]))
+