[Scipy-svn] r4769 - branches/spatial/scipy/spatial/tests

scipy-svn@scip... scipy-svn@scip...
Sat Oct 4 19:43:10 CDT 2008


Author: peridot
Date: 2008-10-04 19:43:08 -0500 (Sat, 04 Oct 2008)
New Revision: 4769

Modified:
   branches/spatial/scipy/spatial/tests/test_kdtree.py
Log:
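Convert the tests to nose style; the shared consistency tests now run against several tree configurations (random data, the unit cube, and the unit cube with leafsize=1).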


Modified: branches/spatial/scipy/spatial/tests/test_kdtree.py
===================================================================
--- branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-04 22:33:54 UTC (rev 4768)
+++ branches/spatial/scipy/spatial/tests/test_kdtree.py	2008-10-05 00:43:08 UTC (rev 4769)
@@ -5,64 +5,17 @@
 import numpy as np
 from scipy.spatial import KDTree, distance
 
-class CheckSmall(NumpyTestCase):
-    def setUp(self):
-        self.data = np.array([[0,0,0],
-                              [0,0,1],
-                              [0,1,0],
-                              [0,1,1],
-                              [1,0,0],
-                              [1,0,1],
-                              [1,1,0],
-                              [1,1,1]])
-        self.kdtree = KDTree(self.data)
-
+class ConsistencyTests:
     def test_nearest(self):
-        assert_array_equal(
-                self.kdtree.query((0,0,0.1), 1),
-                (0.1,0))
-    def test_nearest_two(self):
-        assert_array_equal(
-                self.kdtree.query((0,0,0.1), 2),
-                ([0.1,0.9],[0,1]))
-class CheckSmallNonLeaf(NumpyTestCase):
-    def setUp(self):
-        self.data = np.array([[0,0,0],
-                              [0,0,1],
-                              [0,1,0],
-                              [0,1,1],
-                              [1,0,0],
-                              [1,0,1],
-                              [1,1,0],
-                              [1,1,1]])
-        self.kdtree = KDTree(self.data,leafsize=1)
-
-    def test_nearest(self):
-        assert_array_equal(
-                self.kdtree.query((0,0,0.1), 1),
-                (0.1,0))
-    def test_nearest_two(self):
-        assert_array_equal(
-                self.kdtree.query((0,0,0.1), 2),
-                ([0.1,0.9],[0,1]))
-
-class CheckRandom(NumpyTestCase):
-    def setUp(self):
-        self.n = 1000
-        self.k = 4
-        self.data = np.random.randn(self.n, self.k)
-        self.kdtree = KDTree(self.data)
-
-    def test_nearest(self):
-        x = np.random.randn(self.k)
+        x = self.x
         d, i = self.kdtree.query(x, 1)
         assert_almost_equal(d**2,np.sum((x-self.data[i])**2))
         eps = 1e-8
         assert np.all(np.sum((self.data-x[np.newaxis,:])**2,axis=1)>d**2-eps)
         
     def test_m_nearest(self):
-        x = np.random.randn(self.k)
-        m = 10
+        x = self.x
+        m = self.m
         dd, ii = self.kdtree.query(x, m)
         d = np.amax(dd)
         i = ii[np.argmax(dd)]
@@ -71,8 +24,8 @@
         assert_equal(np.sum(np.sum((self.data-x[np.newaxis,:])**2,axis=1)<d**2+eps),m)
 
     def test_points_near(self):
-        x = np.random.randn(self.k)
-        d = 0.2
+        x = self.x
+        d = self.d
         dd, ii = self.kdtree.query(x, k=self.kdtree.n, distance_upper_bound=d)
         eps = 1e-8
         hits = 0
@@ -85,8 +38,8 @@
         assert_equal(np.sum(np.sum((self.data-x[np.newaxis,:])**2,axis=1)<d**2+eps),hits)
 
     def test_points_near_l1(self):
-        x = np.random.randn(self.k)
-        d = 0.2
+        x = self.x
+        d = self.d
         dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=1, distance_upper_bound=d)
         eps = 1e-8
         hits = 0
@@ -98,8 +51,8 @@
             assert near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d)
         assert_equal(np.sum(distance(self.data,x,1)<d+eps),hits)
     def test_points_near_linf(self):
-        x = np.random.randn(self.k)
-        d = 0.2
+        x = self.x
+        d = self.d
         dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=np.inf, distance_upper_bound=d)
         eps = 1e-8
         hits = 0
@@ -112,13 +65,67 @@
         assert_equal(np.sum(distance(self.data,x,np.inf)<d+eps),hits)
 
     def test_approx(self):
-        x = np.random.randn(self.k)
-        m = 10
+        x = self.x
+        m = self.m
         eps = 0.1
         d_real, i_real = self.kdtree.query(x, m)
         d, i = self.kdtree.query(x, m, eps=eps)
         assert np.all(d<=d_real*(1+eps))
 
+    
+class test_random(ConsistencyTests):
+    """Consistency checks on a KDTree built from 1000 random 4-d points."""
+    def setUp(self):
+        # Use a private generator with a fixed seed so the data and query
+        # point are identical on every run: a failure can then be reproduced
+        # and debugged. A private RandomState also avoids reseeding numpy's
+        # global generator.
+        rng = np.random.RandomState(1234)
+        self.n = 1000
+        self.k = 4
+        self.data = rng.randn(self.n, self.k)
+        self.kdtree = KDTree(self.data)
+        self.x = rng.randn(self.k)
+        self.d = 0.2
+        self.m = 10
+
+class test_small(ConsistencyTests):
+    """Checks on a KDTree built from the 8 vertices of the unit cube.
+
+    Runs the shared consistency tests plus exact-answer queries near the
+    vertex (0,0,0)."""
+    def setUp(self):
+        self.data = np.array([[0,0,0],
+                              [0,0,1],
+                              [0,1,0],
+                              [0,1,1],
+                              [1,0,0],
+                              [1,0,1],
+                              [1,1,0],
+                              [1,1,1]])
+        self.kdtree = KDTree(self.data)
+        self.n = self.kdtree.n
+        self.k = self.kdtree.k
+        # Fixed-seed query point so any failure of the consistency tests
+        # is reproducible.
+        self.x = np.random.RandomState(1234).randn(self.k)
+        self.d = 0.5
+        self.m = 4
+
+    def test_nearest(self):
+        # This method's name shadows ConsistencyTests.test_nearest, which
+        # would otherwise never run for this class (or its subclasses), so
+        # invoke the generic consistency check explicitly as well.
+        ConsistencyTests.test_nearest(self)
+        assert_array_equal(
+                self.kdtree.query((0,0,0.1), 1),
+                (0.1,0))
+    def test_nearest_two(self):
+        assert_array_equal(
+                self.kdtree.query((0,0,0.1), 2),
+                ([0.1,0.9],[0,1]))
+class test_small_nonleaf(test_small):
+    """Repeat the unit-cube tests on a tree built with leafsize=1, so that
+    (as the class name indicates) non-leaf nodes are exercised."""
+    def setUp(self):
+        # Reuse the parent's data and query parameters (x, d, m) so the two
+        # fixtures cannot drift apart; only the tree construction differs.
+        test_small.setUp(self)
+        self.kdtree = KDTree(self.data,leafsize=1)
+        self.n = self.kdtree.n
+        self.k = self.kdtree.k
+
+
 class CheckVectorization(NumpyTestCase):
     def setUp(self):
         self.data = np.array([[0,0,0],
@@ -170,10 +177,3 @@
         assert isinstance(i[0,0],list)
 
 
-
-    
-if __name__=='__main__':
-    import unittest
-    unittest.main()
-
-



More information about the Scipy-svn mailing list