[Scipy-svn] r6554 - in trunk/scipy/optimize: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Jun 21 10:53:30 CDT 2010


Author: oliphant
Date: 2010-06-21 10:53:30 -0500 (Mon, 21 Jun 2010)
New Revision: 6554

Modified:
   trunk/scipy/optimize/minpack.py
   trunk/scipy/optimize/tests/test_minpack.py
Log:
Add tests for curve_fit. Fix fsolve to be consistent with leastsq (i.e., don't special-case the n == 1 call; always return an array).

Modified: trunk/scipy/optimize/minpack.py
===================================================================
--- trunk/scipy/optimize/minpack.py	2010-06-20 23:50:09 UTC (rev 6553)
+++ trunk/scipy/optimize/minpack.py	2010-06-21 15:53:30 UTC (rev 6554)
@@ -151,9 +151,6 @@
             except KeyError:
                 raise errors['unknown'][1](errors['unknown'][0])
 
-    if n == 1:
-        retval = (retval[0][0],) + retval[1:]
-
     if full_output:
         try:
             return retval + (errors[info][0],)  # Return all + the message
@@ -405,6 +402,7 @@
     else:
         func = _weighted_general_function
         args += (1.0/asarray(sigma),)
+
     res = leastsq(func, p0, args=args, full_output=1, **kw)
     (popt, pcov, infodict, errmsg, ier) = res
 

Modified: trunk/scipy/optimize/tests/test_minpack.py
===================================================================
--- trunk/scipy/optimize/tests/test_minpack.py	2010-06-20 23:50:09 UTC (rev 6553)
+++ trunk/scipy/optimize/tests/test_minpack.py	2010-06-21 15:53:30 UTC (rev 6554)
@@ -7,7 +7,7 @@
 from numpy import array, float64
 
 from scipy import optimize
-from scipy.optimize.minpack import fsolve, leastsq
+from scipy.optimize.minpack import fsolve, leastsq, curve_fit
 
 class TestFSolve(TestCase):
     def pressure_network(self, flow_rates, Qtot, k):
@@ -122,5 +122,31 @@
         assert_(ier in (1,2,3,4), 'solution not found: %s'%mesg)
         assert_array_equal(p0, p0_copy)
 
+class TestCurveFit(TestCase):
+    def setUp(self):
+        self.y = array([1.0, 3.2, 9.5, 13.7])
+        self.x = array([1.0, 2.0, 3.0, 4.0])
+
+    def test_one_argument(self):
+        def func(x,a):
+            return x**a
+        popt, pcov = curve_fit(func, self.x, self.y)
+        assert len(popt)==1
+        assert pcov.shape==(1,1)
+        assert_almost_equal(popt[0], 1.9149, decimal=4)
+        assert_almost_equal(pcov[0,0], 0.0016, decimal=4)
+
+    def test_two_argument(self):
+        def func(x, a, b):
+            return b*x**a
+        popt, pcov = curve_fit(func, self.x, self.y)
+        assert len(popt)==2
+        assert pcov.shape==(2,2)
+        assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
+        assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]], decimal=4)
+
+
+
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Scipy-svn mailing list