[Scipy-svn] r3207 - in trunk/Lib/linalg: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Jul 30 09:58:04 CDT 2007


Author: cdavid
Date: 2007-07-30 09:57:57 -0500 (Mon, 30 Jul 2007)
New Revision: 3207

Modified:
   trunk/Lib/linalg/iterative.py
   trunk/Lib/linalg/tests/test_iterative.py
Log:
Copy initial values in iterative solvers to avoid overwriting input arguments. See ticket #470.

Modified: trunk/Lib/linalg/iterative.py
===================================================================
--- trunk/Lib/linalg/iterative.py	2007-07-29 00:54:10 UTC (rev 3206)
+++ trunk/Lib/linalg/iterative.py	2007-07-30 14:57:57 UTC (rev 3207)
@@ -12,6 +12,7 @@
 __all__ = ['bicg','bicgstab','cg','cgs','gmres','qmr']
 from scipy.linalg import _iterative
 import numpy as sb
+import copy
 
 try:
     False, True
@@ -148,9 +149,10 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
     if xtype is None:
         try:
@@ -266,10 +268,12 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
+
     if xtype is None:
         try:
             atyp = A.dtype.char
@@ -376,10 +380,12 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
+
     if xtype is None:
         try:
             atyp = A.dtype.char
@@ -486,9 +492,10 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
     if xtype is None:
         try:
@@ -598,9 +605,10 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
     if xtype is None:
         try:
@@ -710,9 +718,10 @@
     if maxiter is None:
         maxiter = n*10
 
-    x = x0
-    if x is None:
+    if x0 is None:
         x = sb.zeros(n)
+    else:
+        x = copy.copy(x0)
 
     if xtype is None:
         try:

Modified: trunk/Lib/linalg/tests/test_iterative.py
===================================================================
--- trunk/Lib/linalg/tests/test_iterative.py	2007-07-29 00:54:10 UTC (rev 3206)
+++ trunk/Lib/linalg/tests/test_iterative.py	2007-07-30 14:57:57 UTC (rev 3207)
@@ -45,27 +45,39 @@
         b = self.b
 
     def check_cg(self):
+        bx0 = self.x0.copy()
         x, info = cg(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
     def check_bicg(self):
+        bx0 = self.x0.copy()
         x, info = bicg(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
     def check_cgs(self):
+        bx0 = self.x0.copy()
         x, info = cgs(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
     def check_bicgstab(self):
+        bx0 = self.x0.copy()
         x, info = bicgstab(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
     def check_gmres(self):
+        bx0 = self.x0.copy()
         x, info = gmres(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
     def check_qmr(self):
+        bx0 = self.x0.copy()
         x, info = qmr(self.A, self.b, self.x0, callback=callback)
+        assert_array_equal(bx0, self.x0)
         assert norm(dot(self.A, x) - self.b) < 5*self.tol
 
 if __name__ == "__main__":



More information about the Scipy-svn mailing list