[Numpy-svn] r6524 - trunk/numpy/linalg/tests

numpy-svn@scip... numpy-svn@scip...
Mon Mar 2 08:18:19 CST 2009


Author: cdavid
Date: 2009-03-02 08:18:15 -0600 (Mon, 02 Mar 2009)
New Revision: 6524

Modified:
   trunk/numpy/linalg/tests/test_linalg.py
Log:
Abstract away dtype for norm test.

Modified: trunk/numpy/linalg/tests/test_linalg.py
===================================================================
--- trunk/numpy/linalg/tests/test_linalg.py	2009-03-02 14:18:01 UTC (rev 6523)
+++ trunk/numpy/linalg/tests/test_linalg.py	2009-03-02 14:18:15 UTC (rev 6524)
@@ -1,6 +1,7 @@
 """ Test functions for linalg module
 """
 
+import numpy as np
 from numpy.testing import *
 from numpy import array, single, double, csingle, cdouble, dot, identity
 from numpy import multiply, atleast_2d, inf, asarray, matrix
@@ -257,17 +258,19 @@
         evalues, evectors = linalg.eig(a)
         assert_almost_equal(ev, evalues)
 
-class TestNorm(TestCase):
+class _TestNorm(TestCase):
+    dt = None
     def test_empty(self):
         assert_equal(norm([]), 0.0)
-        assert_equal(norm(array([], dtype = double)), 0.0)
-        assert_equal(norm(atleast_2d(array([], dtype = double))), 0.0)
+        assert_equal(norm(array([], dtype=self.dt)), 0.0)
+        assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0)
 
     def test_vector(self):
         a = [1.0,2.0,3.0,4.0]
         b = [-1.0,-2.0,-3.0,-4.0]
         c = [-1.0, 2.0,-3.0, 4.0]
-        for v in (a,array(a),b,array(b),c,array(c)):
+        for v in (a,array(a, dtype=self.dt),b,array(b, dtype=self.dt),c,array(c,
+                  dtype=self.dt)):
             assert_almost_equal(norm(v), 30**0.5)
             assert_almost_equal(norm(v,inf), 4.0)
             assert_almost_equal(norm(v,-inf), 1.0)
@@ -283,19 +286,22 @@
         self.assertRaises(ValueError, norm, array([1., 2., 3.]), 'fro')
 
     def test_matrix(self):
-        A = matrix([[1.,3.],[5.,7.]], dtype=single)
-        A = matrix([[1.,3.],[5.,7.]], dtype=single)
+        A = matrix([[1.,3.],[5.,7.]], dtype=self.dt)
+        A = matrix([[1.,3.],[5.,7.]], dtype=self.dt)
         assert_almost_equal(norm(A), 84**0.5)
         assert_almost_equal(norm(A,'fro'), 84**0.5)
         assert_almost_equal(norm(A,inf), 12.0)
         assert_almost_equal(norm(A,-inf), 4.0)
         assert_almost_equal(norm(A,1), 10.0)
         assert_almost_equal(norm(A,-1), 6.0)
-        assert_almost_equal(norm(A,2), 9.12310563)
-        assert_almost_equal(norm(A,-2), 0.87689437)
+        assert_almost_equal(norm(A,2), 9.1231056256176615)
+        assert_almost_equal(norm(A,-2), 0.87689437438234041)
 
         self.assertRaises(ValueError, norm, A, 'nofro')
         self.assertRaises(ValueError, norm, A, -3)
 
+class TestNormDouble(_TestNorm):
+    dt = np.double
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Numpy-svn mailing list