[Scipy-svn] r6268 - trunk/scipy/sparse/linalg/eigen/arpack

scipy-svn@scip... scipy-svn@scip...
Fri Mar 26 00:34:49 CDT 2010


Author: cdavid
Date: 2010-03-26 00:34:48 -0500 (Fri, 26 Mar 2010)
New Revision: 6268

Modified:
   trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
Log:
REF: put arpack unsymmetric solver parameter checking into a separate object.

Modified: trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
===================================================================
--- trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-24 02:39:31 UTC (rev 6267)
+++ trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:34:48 UTC (rev 6268)
@@ -51,7 +51,88 @@
 _type_conv = {'f':'s', 'd':'d', 'F':'c', 'D':'z'}
 _ndigits = {'f':5, 'd':12, 'F':5, 'D':12}
 
+class _ArpackParams(object):
+    def __init__(self, n, k, tp, mode="symmetric", sigma=None,
+                 ncv=None, v0=None, maxiter=None, which="LM", tol=0):
+        if k <= 0:
+            raise ValueError("k must be positive, k=%d" % k)
+        if k == n:
+            raise ValueError("k must be less than rank(A), k=%d" % k)
 
+        if maxiter is None:
+            maxiter = n * 10
+        if maxiter <= 0:
+            raise ValueError("maxiter must be positive, maxiter=%d" % maxiter)
+
+        if tp not in 'fdFD':
+            raise ValueError("matrix type must be 'f', 'd', 'F', or 'D'")
+
+        if v0 is not None:
+            self.resid = v0
+            info = 1
+        else:
+            self.resid = np.zeros(n, tp)
+            info = 0
+
+        if sigma is not None:
+            raise NotImplementedError("shifted eigenproblem not supported yet")
+
+        if ncv is None:
+            ncv = 2 * k + 1
+        ncv = min(ncv, n)
+
+        if ncv > n or ncv < k:
+            raise ValueError("ncv must be k<=ncv<=n, ncv=%s" % ncv)
+
+        if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
+            raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
+
+        ltr = _type_conv[tp]
+
+        self.v = np.zeros((n, ncv), tp) # holds Ritz vectors
+        self.rwork = None # Only used for unsymmetric, complex solver
+
+        if mode == "unsymmetric":
+            self.workd = np.zeros(3 * n, tp)
+            self.workl = np.zeros(3 * ncv * ncv + 6 * ncv, tp)
+            self.solver = _arpack.__dict__[ltr + 'naupd']
+            self.extract = _arpack.__dict__[ltr + 'neupd']
+
+            if tp in 'FD':
+                self.rwork = np.zeros(ncv, tp.lower())
+
+            self.ipntr = np.zeros(14, "int")
+        elif mode == "symmetric":
+            self.workd = np.zeros(3 * n, tp)
+            self.workl = np.zeros(ncv * (ncv + 8), tp)
+            self.solver = _arpack.__dict__[ltr + 'saupd']
+            self.extract = _arpack.__dict__[ltr + 'seupd']
+
+            self.ipntr = np.zeros(11, "int")
+        else:
+            raise ValueError("Unrecognized mode %s" % mode)
+
+        self.iparam = np.zeros(11, "int")
+
+        # set solver mode and parameters
+        # only supported mode is 1: Ax=lx
+        ishfts = 1
+        mode1 = 1
+        self.iparam[0] = ishfts
+        self.iparam[2] = maxiter
+        self.iparam[6] = mode1
+
+        self.n = n
+        self.mode = mode
+        self.tol = tol
+        self.k = k
+        self.maxiter = maxiter
+        self.ncv = ncv
+        self.which = which
+        self.tp = tp
+        self.info = info
+        self.bmat = 'I'
+
 def eigen(A, k=6, M=None, sigma=None, which='LM', v0=None,
           ncv=None, maxiter=None, tol=0,
           return_eigenvectors=True):
@@ -133,168 +214,113 @@
         raise ValueError('expected square matrix (shape=%s)' % A.shape)
     n = A.shape[0]
 
-    # guess type
-    typ = A.dtype.char
-    if typ not in 'fdFD':
-        raise ValueError("matrix type must be 'f', 'd', 'F', or 'D'")
+    params = _ArpackParams(n, k, A.dtype.char, "unsymmetric", sigma,
+                           ncv, v0, maxiter, which, tol)
 
     if M is not None:
         raise NotImplementedError("generalized eigenproblem not supported yet")
-    if sigma is not None:
-        raise NotImplementedError("shifted eigenproblem not supported yet")
 
-
-    # some defaults
-    if ncv is None:
-        ncv=2*k+1
-    ncv=min(ncv,n)
-    if maxiter==None:
-        maxiter=n*10
-    # assign starting vector
-    if v0 is not None:
-        resid=v0
-        info=1
-    else:
-        resid = np.zeros(n,typ)
-        info=0
-
-
-    # some sanity checks
-    if k <= 0:
-        raise ValueError("k must be positive, k=%d"%k)
-    if k == n:
-        raise ValueError("k must be less than rank(A), k=%d"%k)
-    if maxiter <= 0:
-        raise ValueError("maxiter must be positive, maxiter=%d"%maxiter)
-    whiches=['LM','SM','LR','SR','LI','SI']
-    if which not in whiches:
-        raise ValueError("which must be one of %s"%' '.join(whiches))
-    if ncv > n or ncv < k:
-        raise ValueError("ncv must be k<=ncv<=n, ncv=%s"%ncv)
-
-    # assign solver and postprocessor
-    ltr = _type_conv[typ]
-    eigsolver = _arpack.__dict__[ltr+'naupd']
-    eigextract = _arpack.__dict__[ltr+'neupd']
-
-    v = np.zeros((n,ncv),typ) # holds Ritz vectors
-    workd = np.zeros(3*n,typ) # workspace
-    workl = np.zeros(3*ncv*ncv+6*ncv,typ) # workspace
-    iparam = np.zeros(11,'int') # problem parameters
-    ipntr = np.zeros(14,'int') # pointers into workspaces
     ido = 0
 
-    if typ in 'FD':
-        rwork = np.zeros(ncv,typ.lower())
-
-    # set solver mode and parameters
-    # only supported mode is 1: Ax=lx
-    ishfts = 1
-    mode1 = 1
-    bmat = 'I'
-    iparam[0] = ishfts
-    iparam[2] = maxiter
-    iparam[6] = mode1
-
     while True:
-        if typ in 'fd':
-            ido,resid,v,iparam,ipntr,info =\
-                eigsolver(ido,bmat,which,k,tol,resid,v,iparam,ipntr,
-                          workd,workl,info)
+        if params.tp in 'fd':
+            ido, params.resid, params.v, params.iparam, params.ipntr, params.info = \
+                params.solver(ido, params.bmat, params.which, params.k, params.tol,
+                        params.resid, params.v, params.iparam, params.ipntr,
+                        params.workd, params.workl, params.info)
         else:
-            ido,resid,v,iparam,ipntr,info =\
-                eigsolver(ido,bmat,which,k,tol,resid,v,iparam,ipntr,
-                          workd,workl,rwork,info)
+            ido, params.resid, params.v, params.iparam, params.ipntr, params.info =\
+                params.solver(ido, params.bmat, params.which, params.k, params.tol,
+                        params.resid, params.v, params.iparam, params.ipntr,
+                        params.workd, params.workl, params.rwork, params.info)
 
-        xslice = slice(ipntr[0]-1, ipntr[0]-1+n)
-        yslice = slice(ipntr[1]-1, ipntr[1]-1+n)
+        xslice = slice(params.ipntr[0]-1, params.ipntr[0]-1+n)
+        yslice = slice(params.ipntr[1]-1, params.ipntr[1]-1+n)
         if ido == -1:
             # initialization
-            workd[yslice]=A.matvec(workd[xslice])
+            params.workd[yslice] = A.matvec(params.workd[xslice])
         elif ido == 1:
             # compute y=Ax
-            workd[yslice]=A.matvec(workd[xslice])
+            params.workd[yslice] = A.matvec(params.workd[xslice])
         else:
             break
 
-    if  info < -1 :
-        raise RuntimeError("Error info=%d in arpack"%info)
-        return None
-    if info == -1:
-        warnings.warn("Maximum number of iterations taken: %s"%iparam[2])
-#    if iparam[3] != k:
-#        warnings.warn("Only %s eigenvalues converged"%iparam[3])
+    if params.info < -1 :
+        raise RuntimeError("Error info=%d in arpack" % params.info)
+    elif params.info == -1:
+        warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
 
-
     # now extract eigenvalues and (optionally) eigenvectors
     rvec = return_eigenvectors
     ierr = 0
     howmny = 'A' # return all eigenvectors
-    sselect = np.zeros(ncv,'int') # unused
+    sselect = np.zeros(params.ncv, 'int') # unused
     sigmai = 0.0 # no shifts, not implemented
     sigmar = 0.0 # no shifts, not implemented
-    workev = np.zeros(3*ncv,typ)
+    workev = np.zeros(3 * params.ncv, params.tp)
 
-    if typ in 'fd':
-        dr=np.zeros(k+1,typ)
-        di=np.zeros(k+1,typ)
-        zr=np.zeros((n,k+1),typ)
-        dr,di,zr,info=\
-            eigextract(rvec,howmny,sselect,sigmar,sigmai,workev,
-                   bmat,which,k,tol,resid,v,iparam,ipntr,
-                   workd,workl,info)
+    if params.tp in 'fd':
+        dr = np.zeros(k+1, params.tp)
+        di = np.zeros(k+1, params.tp)
+        zr = np.zeros((n, k+1), params.tp)
+        dr, di, zr, params.info=\
+            params.extract(rvec, howmny, sselect, sigmar, sigmai, workev,
+                   params.bmat, params.which, k, params.tol, params.resid,
+                   params.v, params.iparam, params.ipntr,
+                   params.workd, params.workl, params.info)
 
         # The ARPACK nonsymmetric real and double interface (s,d)naupd return
         # eigenvalues and eigenvectors in real (float,double) arrays.
 
         # Build complex eigenvalues from real and imaginary parts
-        d=dr+1.0j*di
+        d = dr + 1.0j * di
 
         # Arrange the eigenvectors: complex eigenvectors are stored as
         # real,imaginary in consecutive columns
-        z=zr.astype(typ.upper())
-        eps=np.finfo(typ).eps
-        i=0
+        z = zr.astype(params.tp.upper())
+        eps = np.finfo(params.tp).eps
+        i = 0
         while i<=k:
             # check if complex
-            if abs(d[i].imag)>eps:
+            if abs(d[i].imag) > eps:
                 # assume this is a complex conjugate pair with eigenvalues
                 # in consecutive columns
-                z[:,i]=zr[:,i]+1.0j*zr[:,i+1]
-                z[:,i+1]=z[:,i].conjugate()
-                i+=1
-            i+=1
+                z[:,i] = zr[:,i] + 1.0j * zr[:,i+1]
+                z[:,i+1] = z[:,i].conjugate()
+                i +=1
+            i += 1
 
         # Now we have k+1 possible eigenvalues and eigenvectors
         # Return the ones specified by the keyword "which"
-        nreturned=iparam[4] # number of good eigenvalues returned
-        if nreturned==k:    # we got exactly how many eigenvalues we wanted
-            d=d[:k]
-            z=z[:,:k]
+        nreturned = params.iparam[4] # number of good eigenvalues returned
+        if nreturned == k:    # we got exactly how many eigenvalues we wanted
+            d = d[:k]
+            z = z[:,:k]
         else:   # we got one extra eigenvalue (likely a cc pair, but which?)
             # cut at approx precision for sorting
-            rd=np.round(d,decimals=_ndigits[typ])
-            if which in ['LR','SR']:
-                ind=np.argsort(rd.real)
+            rd = np.round(d, decimals = _ndigits[params.tp])
+            if params.which in ['LR','SR']:
+                ind = np.argsort(rd.real)
             elif which in ['LI','SI']:
                 # for LI,SI ARPACK returns largest,smallest abs(imaginary) why?
-                ind=np.argsort(abs(rd.imag))
+                ind = np.argsort(abs(rd.imag))
             else:
-                ind=np.argsort(abs(rd))
-            if which in ['LR','LM','LI']:
-                d=d[ind[-k:]]
-                z=z[:,ind[-k:]]
-            if which in ['SR','SM','SI']:
-                d=d[ind[:k]]
-                z=z[:,ind[:k]]
+                ind = np.argsort(abs(rd))
+            if params.which in ['LR','LM','LI']:
+                d = d[ind[-k:]]
+                z = z[:,ind[-k:]]
+            if params.which in ['SR','SM','SI']:
+                d = d[ind[:k]]
+                z = z[:,ind[:k]]
 
 
     else:
         # complex is so much simpler...
-        d,z,info =\
-              eigextract(rvec,howmny,sselect,sigmar,workev,
-                         bmat,which,k,tol,resid,v,iparam,ipntr,
-                         workd,workl,rwork,ierr)
+        d, z, params.info =\
+                params.extract(rvec, howmny, sselect, sigmar, workev,
+                       params.bmat, params.which, k, params.tol, params.resid,
+                       params.v, params.iparam, params.ipntr,
+                       params.workd, params.workl, params.rwork, ierr)
 
 
 



More information about the Scipy-svn mailing list