[Scipy-svn] r3209 - in trunk/Lib/optimize: . tests

scipy-svn@scip... scipy-svn@scip...
Mon Jul 30 14:17:54 CDT 2007


Author: dmitrey.kroshko
Date: 2007-07-30 14:17:36 -0500 (Mon, 30 Jul 2007)
New Revision: 3209

Modified:
   trunk/Lib/optimize/optimize.py
   trunk/Lib/optimize/tests/test_optimize.py
Log:
Changes in optimize.brent, proposed by Alan Isaac, related to ticket 285.
Added 4 brent() tests.



Modified: trunk/Lib/optimize/optimize.py
===================================================================
--- trunk/Lib/optimize/optimize.py	2007-07-30 15:01:04 UTC (rev 3208)
+++ trunk/Lib/optimize/optimize.py	2007-07-30 19:17:36 UTC (rev 3209)
@@ -1268,19 +1268,152 @@
     else:
         return xf
 
+class Brent:
+    """Brent's method for minimizing a scalar function of one variable.
+
+    Usage: construct with the objective, optionally call set_bracket(),
+    call optimize(), then read the answer with get_result().  optimize()
+    stores its results on the instance as xmin, fval, iter and funcalls.
+    """
+    #need to rethink design of __init__
+    def __init__(self, func, args=(), tol=1.48e-8, maxiter=500, full_output=0):
+        self.func = func              # objective, called as func(x, *args)
+        self.args = args              # tuple of extra arguments for func
+        self.tol = tol                # relative tolerance on x
+        self.maxiter = maxiter        # maximum number of iterations
+        # Kept for reference only; get_result() takes its own flag.
+        self.full_output = full_output
+        # Default bracket, so optimize() also works when set_bracket() is
+        # never called (previously this raised AttributeError).
+        self.brack = None
+        self._mintol = 1.0e-11        # absolute term added to the tolerance
+        self._cg = 0.3819660          # golden-section fraction, (3 - sqrt(5))/2
+        self.xmin = None
+        self.fval = None
+        self.iter = 0
+        self.funcalls = 0
+
+    #need to rethink design of set_bracket (new options, etc)
+    def set_bracket(self, brack = None):
+        """Set the bracket used by optimize().
+
+        brack may be None (search for a bracket with bracket()), a pair
+        (xa, xb) used as starting points for that search, or a triple
+        (xa, xb, xc) with xb strictly between the ends and
+        func(xb) < func(xa), func(xc).
+        """
+        self.brack = brack
+
+    def get_bracket_info(self):
+        """Return (xa, xb, xc, fa, fb, fc, funcalls) for the current bracket.
+
+        funcalls counts the func evaluations made while bracketing (as
+        reported by bracket() when it is used, 3 for an explicit triple).
+
+        Raises AssertionError if an explicit triple is not a bracketing
+        interval, and ValueError if brack is not of length 2 or 3.
+        """
+        #set up
+        func = self.func
+        args = self.args
+        brack = self.brack
+        ### BEGIN core bracket_info code ###
+        ### carefully DOCUMENT any CHANGES in core ##
+        if brack is None:
+            xa,xb,xc,fa,fb,fc,funcalls = bracket(func, args=args)
+        elif len(brack) == 2:
+            xa,xb,xc,fa,fb,fc,funcalls = bracket(func, xa=brack[0], xb=brack[1], args=args)
+        elif len(brack) == 3:
+            xa,xb,xc = brack
+            if (xa > xc):  # swap so xa < xc can be assumed
+                xa, xc = xc, xa
+            # CHANGE: explicit raise instead of `assert`, so the check still
+            # runs under `python -O`; AssertionError is kept so callers that
+            # catch it are unaffected.
+            if not ((xa < xb) and (xb < xc)):
+                raise AssertionError("Not a bracketing interval.")
+            fa = func(*((xa,)+args))
+            fb = func(*((xb,)+args))
+            fc = func(*((xc,)+args))
+            if not ((fb < fa) and (fb < fc)):
+                raise AssertionError("Not a bracketing interval.")
+            funcalls = 3
+        else:
+            raise ValueError("Bracketing interval must be length 2 or 3 sequence.")
+        ### END core bracket_info code ###
+
+        return xa,xb,xc,fa,fb,fc,funcalls
+
+    def optimize(self):
+        """Minimize func by Brent's method and store the results on self.
+
+        Sets self.xmin, self.fval, self.iter (iterations performed) and
+        self.funcalls (total func evaluations, including bracketing).
+        """
+        #set up for optimization
+        func = self.func
+        args = self.args
+        xa,xb,xc,fa,fb,fc,funcalls = self.get_bracket_info()
+        _mintol = self._mintol
+        _cg = self._cg
+        #################################
+        #BEGIN CORE ALGORITHM
+        # Changes from the former brent() loop: f(xb) is taken from the
+        # bracketing step instead of being recomputed, and funcalls keeps
+        # the bracketing count instead of being reset to 1.
+        #################################
+        x=w=v=xb
+        fw=fv=fx=fb
+        if (xa < xc):
+            a = xa; b = xc
+        else:
+            a = xc; b = xa
+        deltax= 0.0
+        niter = 0
+        while (niter < self.maxiter):
+            tol1 = self.tol*abs(x) + _mintol
+            tol2 = 2.0*tol1
+            xmid = 0.5*(a+b)
+            if abs(x-xmid) < (tol2-0.5*(b-a)):  # check for convergence
+                break
+            if (abs(deltax) <= tol1):
+                # deltax starts at 0.0, so (for tol >= 0) the first pass takes
+                # this golden-section branch and defines `rat` before the
+                # parabolic branch below reads it.
+                if (x>=xmid): deltax=a-x       # do a golden section step
+                else: deltax=b-x
+                rat = _cg*deltax
+            else:                              # do a parabolic step
+                tmp1 = (x-w)*(fx-fv)
+                tmp2 = (x-v)*(fx-fw)
+                p = (x-v)*tmp2 - (x-w)*tmp1
+                tmp2 = 2.0*(tmp2-tmp1)
+                if (tmp2 > 0.0): p = -p
+                tmp2 = abs(tmp2)
+                dx_temp = deltax
+                deltax= rat
+                # check parabolic fit
+                if ((p > tmp2*(a-x)) and (p < tmp2*(b-x)) and (abs(p) < abs(0.5*tmp2*dx_temp))):
+                    rat = p*1.0/tmp2        # if parabolic step is useful.
+                    u = x + rat
+                    if ((u-a) < tol2 or (b-u) < tol2):
+                        if xmid-x >= 0: rat = tol1
+                        else: rat = -tol1
+                else:
+                    if (x>=xmid): deltax=a-x # if it's not do a golden section step
+                    else: deltax=b-x
+                    rat = _cg*deltax
+
+            if (abs(rat) < tol1):            # update by at least tol1
+                if rat >= 0: u = x + tol1
+                else: u = x - tol1
+            else:
+                u = x + rat
+            fu = func(*((u,)+args))      # calculate new output value
+            funcalls += 1
+
+            if (fu > fx):                 # if it's bigger than current
+                if (u<x): a=u
+                else: b=u
+                if (fu<=fw) or (w==x):
+                    v=w; w=u; fv=fw; fw=fu
+                elif (fu<=fv) or (v==x) or (v==w):
+                    v=u; fv=fu
+            else:
+                if (u >= x): a = x
+                else: b = x
+                v=w; w=x; x=u
+                fv=fw; fw=fx; fx=fu
+
+            niter += 1
+        #################################
+        #END CORE ALGORITHM
+        #################################
+
+        self.xmin = x
+        self.fval = fx
+        self.iter = niter
+        self.funcalls = funcalls
+
+    def get_result(self, full_output=False):
+        """Return xmin, or (xmin, fval, iter, funcalls) if full_output is true.
+
+        Call optimize() first; until then xmin and fval are None.
+        """
+        if full_output:
+            return self.xmin, self.fval, self.iter, self.funcalls
+        else:
+            return self.xmin
+
 def brent(func, args=(), brack=None, tol=1.48e-8, full_output=0, maxiter=500):
     """ Given a function of one-variable and a possible bracketing interval,
     return the minimum of the function isolated to a fractional precision of
     tol. A bracketing interval is a triple (a,b,c) where (a<b<c) and
-    func(b) < func(a),func(c).  If bracket is two numbers then they are
+    func(b) < func(a),func(c).  If bracket is two numbers (a,c) then they are
     assumed to be a starting interval for a downhill bracket search
-    (see bracket)
+    (see bracket); the minimum found is not guaranteed to satisfy a<=x<=c.
 
     Uses inverse parabolic interpolation when possible to speed up convergence
     of golden section method.
 
 
-    See also:
+    :SeeAlso:
 
       fmin, fmin_powell, fmin_cg,
              fmin_bfgs, fmin_ncg -- multivariate local optimizers
@@ -1300,103 +1433,21 @@
       fixed_point -- scalar fixed-point finder
 
     """
-    _mintol = 1.0e-11
-    _cg = 0.3819660
-    if brack is None:
-        xa,xb,xc,fa,fb,fc,funcalls = bracket(func, args=args)
-    elif len(brack) == 2:
-        xa,xb,xc,fa,fb,fc,funcalls = bracket(func, xa=brack[0], xb=brack[1], args=args)
-    elif len(brack) == 3:
-        xa,xb,xc = brack
-        if (xa > xc):  # swap so xa < xc can be assumed
-            dum = xa; xa=xc; xc=dum
-        assert ((xa < xb) and (xb < xc)), "Not a bracketing interval."
-        fa = func(*((xa,)+args))
-        fb = func(*((xb,)+args))
-        fc = func(*((xc,)+args))
-        assert ((fb<fa) and (fb < fc)), "Not a bracketing interval."
-        funcalls = 3
-    else:
-        raise ValueError, "Bracketing interval must be length 2 or 3 sequence."
 
-    x=w=v=xb
-    fw=fv=fx=func(*((x,)+args))
-    if (xa < xc):
-        a = xa; b = xc
-    else:
-        a = xc; b = xa
-    deltax= 0.0
-    funcalls = 1
-    iter = 0
-    while (iter < maxiter):
-        tol1 = tol*abs(x) + _mintol
-        tol2 = 2.0*tol1
-        xmid = 0.5*(a+b)
-        if abs(x-xmid) < (tol2-0.5*(b-a)):  # check for convergence
-            xmin=x; fval=fx
-            break
-        if (abs(deltax) <= tol1):
-            if (x>=xmid): deltax=a-x       # do a golden section step
-            else: deltax=b-x
-            rat = _cg*deltax
-        else:                              # do a parabolic step
-            tmp1 = (x-w)*(fx-fv)
-            tmp2 = (x-v)*(fx-fw)
-            p = (x-v)*tmp2 - (x-w)*tmp1;
-            tmp2 = 2.0*(tmp2-tmp1)
-            if (tmp2 > 0.0): p = -p
-            tmp2 = abs(tmp2)
-            dx_temp = deltax
-            deltax= rat
-            # check parabolic fit
-            if ((p > tmp2*(a-x)) and (p < tmp2*(b-x)) and (abs(p) < abs(0.5*tmp2*dx_temp))):
-                rat = p*1.0/tmp2        # if parabolic step is useful.
-                u = x + rat
-                if ((u-a) < tol2 or (b-u) < tol2):
-                    if xmid-x >= 0: rat = tol1
-                    else: rat = -tol1
-            else:
-                if (x>=xmid): deltax=a-x # if it's not do a golden section step
-                else: deltax=b-x
-                rat = _cg*deltax
+    # Delegate to the Brent class; `solver` avoids shadowing this function's name.
+    solver = Brent(func=func, args=args, tol=tol, maxiter=maxiter,
+                   full_output=full_output)
+    solver.set_bracket(brack)   # None means: search for a bracket first
+    solver.optimize()
+    return solver.get_result(full_output=full_output)
 
-        if (abs(rat) < tol1):            # update by at least tol1
-            if rat >= 0: u = x + tol1
-            else: u = x - tol1
-        else:
-            u = x + rat
-        fu = func(*((u,)+args))      # calculate new output value
-        funcalls += 1
 
-        if (fu > fx):                 # if it's bigger than current
-            if (u<x): a=u
-            else: b=u
-            if (fu<=fw) or (w==x):
-                v=w; w=u; fv=fw; fw=fu
-            elif (fu<=fv) or (v==x) or (v==w):
-                v=u; fv=fu
-        else:
-            if (u >= x): a = x
-            else: b = x
-            v=w; w=x; x=u
-            fv=fw; fw=fx; fx=fu
 
-        iter += 1
-
-    xmin = x
-    fval = fx
-    if full_output:
-        return xmin, fval, iter, funcalls
-    else:
-        return xmin
-
 def golden(func, args=(), brack=None, tol=_epsilon, full_output=0):
     """ Given a function of one-variable and a possible bracketing interval,
     return the minimum of the function isolated to a fractional precision of
     tol. A bracketing interval is a triple (a,b,c) where (a<b<c) and
-    func(b) < func(a),func(c).  If bracket is two numbers then they are
+    func(b) < func(a),func(c).  If bracket is two numbers (a, c) then they are
     assumed to be a starting interval for a downhill bracket search
-    (see bracket)
+    (see bracket); the minimum found is not guaranteed to satisfy a<=x<=c.
 
     Uses analog of bisection method to decrease the bracketed interval.
 
@@ -1474,7 +1525,7 @@
     """Given a function and distinct initial points, search in the downhill
     direction (as defined by the initital points) and return new points
     xa, xb, xc that bracket the minimum of the function:
-    f(xa) > f(xb) < f(xc)
+    f(xa) > f(xb) < f(xc). The returned points need not lie within the initial interval [xa, xb].
     """
     _gold = 1.618034
     _verysmall_num = 1e-21

Modified: trunk/Lib/optimize/tests/test_optimize.py
===================================================================
--- trunk/Lib/optimize/tests/test_optimize.py	2007-07-30 15:01:04 UTC (rev 3208)
+++ trunk/Lib/optimize/tests/test_optimize.py	2007-07-30 19:17:36 UTC (rev 3209)
@@ -126,6 +126,23 @@
         #print "LBFGSB: Difference is: " + str(err)
         assert err < 1e-6
 
+    def test_brent(self):
+        """ brent algorithm: default bracket, 2-point and 3-point brackets,
+        and the full_output tuple.
+        """
+        def f(x):
+            # minimum value -0.8 at x = 1.5
+            return (x-1.5)**2-0.8
+
+        x = optimize.brent(f)
+        err1 = abs(x - 1.5)
+        x = optimize.brent(f, brack = (-3,-2))
+        err2 = abs(x - 1.5)
+        x = optimize.brent(f, full_output=True)
+        err3 = abs(x[0] - 1.5)
+        # full_output returns (xmin, fval, iter, funcalls); also check fval.
+        assert len(x) == 4
+        assert abs(x[1] + 0.8) < 1e-10
+        x = optimize.brent(f, brack = (-15,-1,15))
+        err4 = abs(x - 1.5)
+
+        assert max((err1,err2,err3,err4)) < 1e-6
+
+
 class test_tnc(NumpyTestCase):
     """TNC non-linear optimization.
 
@@ -225,5 +242,6 @@
         if ef > 1e-8:
             raise err
 
+
 if __name__ == "__main__":
     NumpyTest().run()



More information about the Scipy-svn mailing list