[Scipy-svn] r3196 - in trunk/Lib/linalg: src tests

scipy-svn@scip... scipy-svn@scip...
Thu Jul 26 08:47:39 CDT 2007


Author: cdavid
Date: 2007-07-26 08:47:33 -0500 (Thu, 26 Jul 2007)
New Revision: 3196

Modified:
   trunk/Lib/linalg/src/lu.f
   trunk/Lib/linalg/tests/test_decomp.py
Log:
* Add more tests for LU decomposition: rectangular matrices are now tested.
* Both single and double precision are tested, too.
* Correct the swap order in the Fortran wrappers for the LU functions (fixes #427).
* Set the right dimension in the Fortran wrapper when swapping L when permute_l is true
(fixes crash #468).



Modified: trunk/Lib/linalg/src/lu.f
===================================================================
--- trunk/Lib/linalg/src/lu.f	2007-07-26 12:00:22 UTC (rev 3195)
+++ trunk/Lib/linalg/src/lu.f	2007-07-26 13:47:33 UTC (rev 3196)
@@ -43,12 +43,12 @@
  10      continue
  20   continue
       if (permute_l.ne.0) then
-         call dlaswp(n,l,m,1,k,piv,1)
+         call dlaswp(k,l,m,1,k,piv,-1)
       else
          do 25 i=1,m
             p(i,i)=1d0
  25       continue
-         call dlaswp(m,p,m,1,k,piv,1)
+         call dlaswp(m,p,m,1,k,piv,-1)
       endif
       end
 
@@ -90,12 +90,12 @@
  10      continue
  20   continue
       if (permute_l.ne.0) then
-         call zlaswp(n,l,m,1,k,piv,1)
+         call zlaswp(k,l,m,1,k,piv,-1)
       else
          do 25 i=1,m
             p(i,i)=1d0
  25       continue
-         call dlaswp(m,p,m,1,k,piv,1)
+         call dlaswp(m,p,m,1,k,piv,-1)
       endif
       end
 
@@ -137,12 +137,12 @@
  10      continue
  20   continue
       if (permute_l.ne.0) then
-         call slaswp(n,l,m,1,k,piv,1)
+         call slaswp(k,l,m,1,k,piv,-1)
       else
          do 25 i=1,m
             p(i,i)=1e0
  25       continue
-         call slaswp(m,p,m,1,k,piv,1)
+         call slaswp(m,p,m,1,k,piv,-1)
       endif
       end
 
@@ -184,11 +184,11 @@
  10      continue
  20   continue
       if (permute_l.ne.0) then
-         call claswp(n,l,m,1,k,piv,1)
+         call claswp(k,l,m,1,k,piv,-1)
       else
          do 25 i=1,m
             p(i,i)=1e0
  25       continue
-         call slaswp(m,p,m,1,k,piv,1)
+         call slaswp(m,p,m,1,k,piv,-1)
       endif
       end

Modified: trunk/Lib/linalg/tests/test_decomp.py
===================================================================
--- trunk/Lib/linalg/tests/test_decomp.py	2007-07-26 12:00:22 UTC (rev 3195)
+++ trunk/Lib/linalg/tests/test_decomp.py	2007-07-26 13:47:33 UTC (rev 3196)
@@ -18,12 +18,11 @@
 from numpy.testing import *
 
 set_package_path()
-from scipy.linalg import eig,eigvals,lu,svd,svdvals,cholesky,qr,schur,rsf2csf
-from scipy.linalg import lu_solve,lu_factor,solve,diagsvd,hessenberg,rq
-from scipy.linalg import eig_banded,eigvals_banded
-from scipy.linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs
-from scipy.linalg.flapack import dsbev, dsbevd, dsbevx, zhbevd, zhbevx
-
+from linalg import eig,eigvals,lu,svd,svdvals,cholesky,qr,schur,rsf2csf
+from linalg import lu_solve,lu_factor,solve,diagsvd,hessenberg,rq
+from linalg import eig_banded,eigvals_banded
+from linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs
+from linalg.flapack import dsbev, dsbevd, dsbevx, zhbevd, zhbevx
 restore_path()
 
 from numpy import *
@@ -439,22 +438,87 @@
 
 class test_lu(NumpyTestCase):
 
+    def __init__(self, *args, **kw):
+        NumpyTestCase.__init__(self, *args, **kw)
+
+        self.a = array([[1,2,3],[1,2,3],[2,5,6]])
+        self.ca = array([[1,2,3],[1,2,3],[2,5j,6]])
+        # These matrices are more likely than the ones above to expose
+        # errors in the computed permutation matrices
+        self.b = array([[1,2,3],[4,5,6],[7,8,9]])
+        self.cb = array([[1j,2j,3j],[4j,5j,6j],[7j,8j,9j]])
+
+        # Rectangular matrices: 3x4 (wide) and 4x3 (tall)
+        self.hrect = array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
+        self.chrect = 1.j * array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
+
+        self.vrect = array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
+        self.cvrect = 1.j * array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
+
+        # Medium-size 30x40 random matrices (`rand` must be imported elsewhere; TODO confirm)
+        self.med = rand(30, 40)
+        self.cmed = rand(30, 40) + 1.j * rand(30, 40)
+
+    def _test_common(self, data):  # check P*L*U and (P*L)*U reproduce data
+        p,l,u = lu(data)
+        assert_array_almost_equal(dot(dot(p,l),u),data)
+        pl,u = lu(data,permute_l=1)
+        assert_array_almost_equal(dot(pl,u),data)
+
+    # Simple square (3x3) tests
     def check_simple(self):
-        a = [[1,2,3],[1,2,3],[2,5,6]]
-        p,l,u = lu(a)
-        assert_array_almost_equal(dot(dot(p,l),u),a)
-        pl,u = lu(a,permute_l=1)
-        assert_array_almost_equal(dot(pl,u),a)
+        self._test_common(self.a)
 
     def check_simple_complex(self):
-        a = [[1,2,3],[1,2,3],[2,5j,6]]
-        p,l,u = lu(a)
-        assert_array_almost_equal(dot(dot(p,l),u),a)
-        pl,u = lu(a,permute_l=1)
-        assert_array_almost_equal(dot(pl,u),a)
+        self._test_common(self.ca)
 
-    #XXX: need more tests
+    def check_simple2(self):
+        self._test_common(self.b)
 
+    def check_simple2_complex(self):
+        self._test_common(self.cb)
+
+    # Rectangular matrix tests
+    def check_hrectangular(self):
+        self._test_common(self.hrect)
+
+    def check_vrectangular(self):
+        self._test_common(self.vrect)
+
+    def check_hrectangular_complex(self):
+        self._test_common(self.chrect)
+
+    def check_vrectangular_complex(self):
+        self._test_common(self.cvrect)
+
+    # Bigger (30x40) matrices
+    def check_medium1(self, level = 2):
+        """Check lu decomposition on medium size, rectangular matrix."""
+        self._test_common(self.med)
+
+    def check_medium1_complex(self, level = 2):
+        """Check lu decomposition on medium size, complex rectangular matrix."""
+        self._test_common(self.cmed)
+
+class test_lu_single(test_lu):
+    """Rerun every test_lu check in single precision (float32/complex64)."""
+    def __init__(self, *args, **kw):
+        test_lu.__init__(self, *args, **kw)
+        # Cast each matrix from its own double-precision counterpart so the
+        # complex ones keep their imaginary part and med/cmed stay 30x40.
+        for name in ('a', 'b', 'hrect', 'vrect', 'med'):
+            setattr(self, name, getattr(self, name).astype(float32))
+            cname = 'c' + name
+            setattr(self, cname, getattr(self, cname).astype(complex64))
+
+    def _test_common(self, data):
+        # float32 round-off can exceed decimal=6 on the 30x40 matrices;
+        # permutation bugs give O(1) errors, so decimal=4 still catches them.
+        p,l,u = lu(data)
+        assert_array_almost_equal(dot(dot(p,l),u),data,decimal=4)
+        pl,u = lu(data,permute_l=1)
+        assert_array_almost_equal(dot(pl,u),data,decimal=4)
+
 class test_lu_solve(NumpyTestCase):
     def check_lu(self):
         a = random((10,10))



More information about the Scipy-svn mailing list