[Scipy-svn] r6558 - in branches/0.8.x/scipy/optimize: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Jun 21 11:20:41 CDT 2010


Author: oliphant
Date: 2010-06-21 11:20:41 -0500 (Mon, 21 Jun 2010)
New Revision: 6558

Modified:
   branches/0.8.x/scipy/optimize/minpack.py
   branches/0.8.x/scipy/optimize/tests/test_minpack.py
Log:
Fix curve_fit in the 0.8.x branch so it accepts every leastsq success code (ier in 1-4) rather than only ier == 1, and add a curve_fit test in that branch.

Modified: branches/0.8.x/scipy/optimize/minpack.py
===================================================================
--- branches/0.8.x/scipy/optimize/minpack.py	2010-06-21 16:18:01 UTC (rev 6557)
+++ branches/0.8.x/scipy/optimize/minpack.py	2010-06-21 16:20:41 UTC (rev 6558)
@@ -416,7 +416,7 @@
     res = leastsq(func, p0, args=args, full_output=1, **kw)
     (popt, pcov, infodict, errmsg, ier) = res
 
-    if ier != 1:
+    if ier not in [1,2,3,4]:
         msg = "Optimal parameters not found: " + errmsg
         raise RuntimeError(msg)
 

Modified: branches/0.8.x/scipy/optimize/tests/test_minpack.py
===================================================================
--- branches/0.8.x/scipy/optimize/tests/test_minpack.py	2010-06-21 16:18:01 UTC (rev 6557)
+++ branches/0.8.x/scipy/optimize/tests/test_minpack.py	2010-06-21 16:20:41 UTC (rev 6558)
@@ -7,7 +7,7 @@
 from numpy import array, float64
 
 from scipy import optimize
-from scipy.optimize.minpack import fsolve, leastsq
+from scipy.optimize.minpack import fsolve, leastsq, curve_fit
 
 class TestFSolve(TestCase):
     def pressure_network(self, flow_rates, Qtot, k):
@@ -122,5 +122,31 @@
         assert_(ier in (1,2,3,4), 'solution not found: %s'%mesg)
         assert_array_equal(p0, p0_copy)
 
+class TestCurveFit(TestCase):
+    def setUp(self):
+        self.y = array([1.0, 3.2, 9.5, 13.7])
+        self.x = array([1.0, 2.0, 3.0, 4.0])
+
+    def test_one_argument(self):
+        def func(x,a):
+            return x**a
+        popt, pcov = curve_fit(func, self.x, self.y)
+        assert len(popt)==1
+        assert pcov.shape==(1,1)
+        assert_almost_equal(popt[0], 1.9149, decimal=4)
+        assert_almost_equal(pcov[0,0], 0.0016, decimal=4)
+
+    def test_two_argument(self):
+        def func(x, a, b):
+            return b*x**a
+        popt, pcov = curve_fit(func, self.x, self.y)
+        assert len(popt)==2
+        assert pcov.shape==(2,2)
+        assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
+        assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]], decimal=4)
+
+
+
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Scipy-svn mailing list