[Scipy-svn] r2228 - in trunk/Lib/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Sep 24 02:57:54 CDT 2006


Author: rkern
Date: 2006-09-24 02:57:51 -0500 (Sun, 24 Sep 2006)
New Revision: 2228

Modified:
   trunk/Lib/linalg/decomp.py
   trunk/Lib/linalg/generic_flapack.pyf
   trunk/Lib/linalg/tests/test_decomp.py
Log:
Add 'economy' QR decomposition from patch in #220

Modified: trunk/Lib/linalg/decomp.py
===================================================================
--- trunk/Lib/linalg/decomp.py	2006-09-24 07:11:23 UTC (rev 2227)
+++ trunk/Lib/linalg/decomp.py	2006-09-24 07:57:51 UTC (rev 2228)
@@ -6,10 +6,10 @@
 # additions by Travis Oliphant, March 2002
 # additions by Eric Jones,      June 2002
 # additions by Johannes Loehnert, June 2006
+# additions by Bart Vandereycken, June 2006
 
-
 __all__ = ['eig','eigh','eig_banded','eigvals','eigvalsh', 'eigvals_banded',
-           'lu','svd','svdvals','diagsvd','cholesky','qr',
+           'lu','svd','svdvals','diagsvd','cholesky','qr','qr_old',
            'schur','rsf2csf','lu_factor','cho_factor','cho_solve','orth',
            'hessenberg']
 
@@ -17,7 +17,7 @@
 import basic
 
 from warnings import warn
-from lapack import get_lapack_funcs
+from lapack import get_lapack_funcs, find_best_lapack_type
 from blas import get_blas_funcs
 from flinalg import get_flinalg_funcs
 from scipy.linalg import calc_lwork
@@ -581,8 +581,90 @@
         raise TypeError, msg
     return b
 
+def qr(a,overwrite_a=0,lwork=None,econ=False,mode='qr'):
+    """QR decomposition of an M x N matrix a.
 
-def qr(a,overwrite_a=0,lwork=None):
+    Description:
+
+      Find a unitary matrix, q, and an upper-trapezoidal matrix r
+      such that q * r = a
+
+    Inputs:
+
+      a -- the matrix
+      overwrite_a=0 -- if non-zero then discard the contents of a,
+                     i.e. a is used as a work array if possible.
+
+      lwork=None -- >= shape(a)[1]. If None (or -1) compute optimal
+                    work array size.
+      econ=False -- if true and M >= N, return the economy-size
+                    factors: q is M x N and r is N x N
+      mode='qr' -- if 'qr' then return both q and r; if 'r' then just return r
+
+    Outputs:
+      q,r  - if mode=='qr'
+      r    - if mode=='r'
+
+    Raises:
+      ValueError -- if a is not 2D, or if a LAPACK routine reports an
+                    illegal argument (info < 0)
+    """
+    a1 = asarray_chkfinite(a)
+    if len(a1.shape) != 2:
+        raise ValueError("expected 2D array")
+    M, N = a1.shape
+    overwrite_a = overwrite_a or (_datanotshared(a1,a))
+
+    geqrf, = get_lapack_funcs(('geqrf',),(a1,))
+    if lwork is None or lwork == -1:
+        # workspace query: the optimal size comes back in work[0]
+        qr,tau,work,info = geqrf(a1,lwork=-1,overwrite_a=1)
+        lwork = int(work[0].real)
+
+    qr,tau,work,info = geqrf(a1,lwork=lwork,overwrite_a=overwrite_a)
+    if info<0:
+        raise ValueError("illegal value in %d-th argument of internal geqrf"
+            % -info)
+
+    if not econ or M<N:
+        R = basic.triu(qr)
+    else:
+        R = basic.triu(qr[0:N,0:N])
+
+    if mode=='r':
+        return R
+
+    # real types ('s','d') use orgqr, complex types use ungqr
+    if find_best_lapack_type((a1,))[0] in ('s','d'):
+        gor_un_gqr, = get_lapack_funcs(('orgqr',),(qr,))
+    else:
+        gor_un_gqr, = get_lapack_funcs(('ungqr',),(qr,))
+
+    # pick the matrix that q is generated from
+    if M<N:
+        # wide input: q is M x M, built from the first M columns
+        qqr = qr[:,0:M]
+    elif econ:
+        qqr = qr
+    else:
+        # full q is M x M: copy the factored columns into a square array
+        qqr = numpy.empty((M,M),dtype=qr.dtype.char)
+        qqr[:,0:N]=qr
+
+    # workspace query first (lwork=-1), then the actual call
+    Q,work,info = gor_un_gqr(qqr,tau,lwork=-1,overwrite_a=1)
+    lwork = int(work[0].real)
+    Q,work,info = gor_un_gqr(qqr,tau,lwork=lwork,overwrite_a=1)
+
+    if info < 0:
+        raise ValueError("illegal value in %d-th argument of internal "
+            "orgqr/ungqr" % -info)
+
+    return Q, R
+
+
+
+def qr_old(a,overwrite_a=0,lwork=None):
     """QR decomposition of an M x N matrix a.
 
     Description:

Modified: trunk/Lib/linalg/generic_flapack.pyf
===================================================================
--- trunk/Lib/linalg/generic_flapack.pyf	2006-09-24 07:11:23 UTC (rev 2227)
+++ trunk/Lib/linalg/generic_flapack.pyf	2006-09-24 07:57:51 UTC (rev 2228)
@@ -332,6 +332,48 @@
      integer intent(out) :: info
    end subroutine <tchar=s,d,c,z>geqrf
 
+   subroutine <tchar=s,d>orgqr(m,n,k,a,tau,work,lwork,info)
+
+   ! q,work,info = orgqr(a,tau,lwork=3*n,overwrite_a=0)
+   ! Generates an M-by-N real matrix Q with orthonormal columns,
+   ! which is defined as the first N columns of a product of K elementary
+   ! reflectors of order M (e.g. output of geqrf)
+
+     callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+     callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+     integer intent(hide),depend(a):: m = shape(a,0)
+     integer intent(hide),depend(a):: n = shape(a,1)
+     integer intent(hide),depend(tau):: k = shape(tau,0)
+     <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+     <type_in> dimension(k),intent(in) :: tau
+
+     integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*n
+     <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+     integer intent(out) :: info
+   end subroutine <tchar=s,d>orgqr
+
+   subroutine <tchar=c,z>ungqr(m,n,k,a,tau,work,lwork,info)
+
+   ! q,work,info = ungqr(a,tau,lwork=3*n,overwrite_a=0)
+   ! Generates an M-by-N complex matrix Q with unitary columns,
+   ! which is defined as the first N columns of a product of K elementary
+   ! reflectors of order M (e.g. output of geqrf)
+
+     callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+     callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+     integer intent(hide),depend(a):: m = shape(a,0)
+     integer intent(hide),depend(a):: n = shape(a,1)
+     integer intent(hide),depend(tau):: k = shape(tau,0)
+     <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+     <type_in> dimension(k),intent(in) :: tau
+
+     integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*n
+     <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+     integer intent(out) :: info
+   end subroutine <tchar=c,z>ungqr
+
    subroutine <tchar=s,d>geev(compute_vl,compute_vr,n,a,wr,wi,vl,ldvl,vr,ldvr,work,lwork,info)
 
      ! wr,wi,vl,vr,info = geev(a,compute_vl=1,compute_vr=1,lwork=4*n,overwrite_a=0)

Modified: trunk/Lib/linalg/tests/test_decomp.py
===================================================================
--- trunk/Lib/linalg/tests/test_decomp.py	2006-09-24 07:11:23 UTC (rev 2227)
+++ trunk/Lib/linalg/tests/test_decomp.py	2006-09-24 07:57:51 UTC (rev 2228)
@@ -599,6 +599,28 @@
         assert_array_almost_equal(dot(transpose(q),q),identity(3))
         assert_array_almost_equal(dot(q,r),a)
 
+    def check_simple_trap(self):
+        a = [[8,2,3],[2,9,3]]  # 2x3 input: fewer rows than columns
+        q,r = qr(a)
+        assert_array_almost_equal(dot(transpose(q),q),identity(2))  # q^T q == I (2x2)
+        assert_array_almost_equal(dot(q,r),a)  # the factors reproduce a
+
+    def check_simple_tall(self):
+        # full version: for a 3x2 input, q^T q is compared with identity(3)
+        a = [[8,2],[2,9],[5,3]]
+        q,r = qr(a)
+        assert_array_almost_equal(dot(transpose(q),q),identity(3))
+        assert_array_almost_equal(dot(q,r),a)
+
+    def check_simple_tall_e(self):
+        # economy version: for a 3x2 input, q must be 3x2 and r must be 2x2
+        a = [[8,2],[2,9],[5,3]]
+        q,r = qr(a,econ=True)
+        assert_array_almost_equal(dot(transpose(q),q),identity(2))
+        assert_array_almost_equal(dot(q,r),a)
+        assert_equal(q.shape, (3,2))
+        assert_equal(r.shape, (2,2))
+
     def check_simple_complex(self):
         a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
         q,r = qr(a)
@@ -613,6 +635,37 @@
             assert_array_almost_equal(dot(transpose(q),q),identity(n))
             assert_array_almost_equal(dot(q,r),a)
 
+    def check_random_tall(self):
+        # full version: random m x n input with m > n; q^T q is compared with identity(m)
+        m = 200
+        n = 100
+        for k in range(2):
+            a = random([m,n])
+            q,r = qr(a)
+            assert_array_almost_equal(dot(transpose(q),q),identity(m))
+            assert_array_almost_equal(dot(q,r),a)
+
+    def check_random_tall_e(self):
+        # economy version: random m x n input with m > n; q must be m x n, r must be n x n
+        m = 200
+        n = 100
+        for k in range(2):
+            a = random([m,n])
+            q,r = qr(a,econ=True)
+            assert_array_almost_equal(dot(transpose(q),q),identity(n))
+            assert_array_almost_equal(dot(q,r),a)
+            assert_equal(q.shape, (m,n))
+            assert_equal(r.shape, (n,n))
+
+    def check_random_trap(self):
+        m = 100  # fewer rows than columns (m < n)
+        n = 200
+        for k in range(2):
+            a = random([m,n])
+            q,r = qr(a)
+            assert_array_almost_equal(dot(transpose(q),q),identity(m))  # q^T q == I (m x m)
+            assert_array_almost_equal(dot(q,r),a)  # the factors reproduce a
+
     def check_random_complex(self):
         n = 20
         for k in range(2):



More information about the Scipy-svn mailing list