[Numpy-svn] r3098 - trunk/numpy/oldnumeric

numpy-svn at scipy.org numpy-svn at scipy.org
Tue Aug 29 13:29:58 CDT 2006


Author: oliphant
Date: 2006-08-29 13:29:53 -0500 (Tue, 29 Aug 2006)
New Revision: 3098

Added:
   trunk/numpy/oldnumeric/fix_default_axis.py
Log:
Add a module/script to fix the default axis issue for code already converted to NumPy

Added: trunk/numpy/oldnumeric/fix_default_axis.py
===================================================================
--- trunk/numpy/oldnumeric/fix_default_axis.py	2006-08-29 17:56:21 UTC (rev 3097)
+++ trunk/numpy/oldnumeric/fix_default_axis.py	2006-08-29 18:29:53 UTC (rev 3098)
@@ -0,0 +1,292 @@
+"""
+This module adds the default axis argument to code which did not specify it
+for the functions where the default was changed in NumPy.
+
+The functions changed are
+
+add -1  ( all second argument)
+======
+nansum
+nanmax
+nanmin
+nanargmax
+nanargmin
+argmax
+argmin
+compress 3
+
+
+add 0
+======
+take     3
+repeat   3
+sum         # might cause problems with builtin.
+product
+sometrue
+alltrue
+cumsum
+cumproduct
+average
+ptp
+cumprod
+prod
+std
+mean
+"""
+__all__ = ['convertfile', 'convertall', 'converttree',
+           'convertfile2','convertall2', 'converttree2']
+
+import sys
+import os
+import re
+import glob
+
+
+_args3 = ['compress', 'take', 'repeat']
+_funcm1 = ['nansum', 'nanmax', 'nanmin', 'nanargmax', 'nanargmin',
+           'argmax', 'argmin', 'compress']
+_func0 = ['take', 'repeat', 'sum', 'product', 'sometrue', 'alltrue',
+          'cumsum', 'cumproduct', 'average', 'ptp', 'cumprod', 'prod',
+          'std', 'mean']
+
+_all = _func0 + _funcm1
+func_re = {}
+
+for name in _all:
+    _astr = r"""%s\s*[(]"""%name 
+    func_re[name] = re.compile(_astr)
+
+
+import string
+disallowed = '_' + string.uppercase + string.lowercase + string.digits
+
+def _add_axis(fstr, name, repl):
+    alter = 0
+    if name in _args3:
+        allowed_comma = 1
+    else:
+        allowed_comma = 0
+    newcode = ""
+    last = 0
+    for obj in func_re[name].finditer(fstr):
+        nochange = 0
+        start, end = obj.span()
+        if fstr[start-1] in disallowed:
+            continue
+        if fstr[start-1] == '.' \
+           and fstr[start-6:start-1] != 'numpy' \
+           and fstr[start-2:start-1] != 'N' \
+           and fstr[start-9:start-1] != 'numarray' \
+           and fstr[start-8:start-1] != 'numerix' \
+           and fstr[start-8:start-1] != 'Numeric':
+            continue
+        if fstr[start-1] in ['\t',' ']:
+            k = start-2
+            while fstr[k] in ['\t',' ']:
+                k -= 1
+            if fstr[k-2:k+1] == 'def' or \
+               fstr[k-4:k+1] == 'class':
+                continue
+        k = end
+        stack = 1
+        ncommas = 0
+        N = len(fstr)
+        while stack:
+            if k>=N:
+                nochange =1
+                break
+            if fstr[k] == ')':
+                stack -= 1
+            elif fstr[k] == '(':
+                stack += 1
+            elif stack == 1 and fstr[k] == ',':
+                ncommas += 1
+                if ncommas > allowed_comma:
+                    nochange = 1
+                    break
+            k += 1
+        if nochange:
+            continue
+        alter += 1
+        newcode = "%s%s,%s)" % (newcode, fstr[last:k-1], repl)
+        last = k
+    if not alter:
+        newcode = fstr
+    else:
+        newcode = "%s%s" % (newcode, fstr[last:])
+    return newcode, alter
+
+def _import_change(fstr, names):
+    # Four possibilities
+    #  1.) import numpy with subsequent use of numpy.<name>
+    #        change this to import numpy.oldnumeric as numpy
+    #  2.) import numpy as XXXX with subsequent use of
+    #        XXXX.<name> ==> import numpy.oldnumeric as XXXX
+    #  3.) from numpy import *
+    #        with subsequent use of one of the names
+    #  4.) from numpy import ..., <name>, ... (could span multiple
+    #        lines.  ==> remove all names from list and
+    #        add from numpy.oldnumeric import <name>
+
+    num = 0
+    # case 1
+    importstr = "import numpy"
+    ind = fstr.find(importstr)
+    if (ind > 0):
+        found = 0
+        for name in names:
+            ind2 = fstr.find("numpy.%s" % name, ind)
+            if (ind2 > 0):
+                found = 1
+                break
+        if found:
+            fstr = "%s%s%s" % (fstr[:ind], "import numpy.oldnumeric as numpy",
+                               fstr[ind+len(importstr):])
+            num += 1
+            
+    # case 2
+    importre = re.compile("""import numpy as ([A-Za-z0-9_]+)""")
+    modules = importre.findall(fstr)
+    if len(modules) > 0:
+        for module in modules:
+            found = 0
+            for name in names:
+                ind2 = fstr.find("%s.%s" % (module, name))
+                if (ind2 > 0):
+                    found = 1
+                    break
+            if found:
+                importstr = "import numpy as %s" % module
+                ind = fstr.find(importstr)
+                fstr = "%s%s%s" % (fstr[:ind],
+                                   "import numpy.oldnumeric as %s" % module,
+                                   fstr[ind+len(importstr):])
+                num += 1
+
+    # case 3
+    importstr = "from numpy import *"
+    ind = fstr.find(importstr)
+    if (ind > 0):
+        found = 0
+        for name in names:
+            ind2 = fstr.find(name, ind)
+            if (ind2 > 0) and fstr[ind2-1] not in disallowed:
+                found = 1
+                break
+        if found:
+            fstr = "%s%s%s" % (fstr[:ind],
+                               "from numpy.oldnumeric import *",
+                               fstr[ind+len(importstr):])
+            num += 1
+
+    # case 4
+    ind = 0
+    importstr = "from numpy import"
+    N = len(importstr)
+    while 1:
+        ind = fstr.find(importstr, ind)
+        if (ind < 0):
+            break
+        ind += N
+        ptr = ind+1
+        stack = 1
+        while stack:
+            if fstr[ptr] == '\\':
+                stack += 1
+            elif fstr[ptr] == '\n':
+                stack -= 1
+            ptr += 1
+        substr = fstr[ind:ptr]
+        found = 0
+        substr = substr.replace('\n',' ')
+        substr = substr.replace('\\','')
+        importnames = [x.strip() for x in substr.split(',')]
+        # determine if any of names are in importnames
+        addnames = []
+        for name in names:
+            if name in importnames:
+                importnames.remove(name)
+                addnames.append(name)
+        if len(addnames) > 0:
+            fstr = "%s%s\n%s\n%s" % \
+                   (fstr[:ind],
+                    "from numpy import %s" % \
+                    ", ".join(importnames),
+                    "from numpy.oldnumeric import %s" % \
+                    ", ".join(addnames),
+                    fstr[ptr:])
+            num += 1
+
+    return fstr, num
+
+def add_axis(fstr, import_change=False):
+    total = 0
+    if not import_change:
+        for name in _funcm1:
+            fstr, num = _add_axis(fstr, name, 'axis=-1')
+            total += num
+        for name in _func0:
+            fstr, num = _add_axis(fstr, name, 'axis=0')
+            total += num
+        return fstr, total
+    else:
+        fstr, num = _import_change(fstr, _funcm1+_func0)
+        return fstr, num        
+
+
+def makenewfile(name, filestr):
+    fid = file(name, 'w')
+    fid.write(filestr)
+    fid.close()
+
+def getfile(name):
+    fid = file(name)
+    filestr = fid.read()
+    fid.close()
+    return filestr
+
+def copyfile(name, fstr):
+    base, ext = os.path.splitext(name)
+    makenewfile(base+'.orig', fstr)
+    return
+
+def convertfile(filename, import_change=False):
+    """Convert the filename given from using Numeric to using NumPy
+
+    Copies the file to filename.orig and then over-writes the file
+    with the updated code
+    """
+    filestr = getfile(filename)
+    newstr, total = add_axis(filestr, import_change)
+    if total > 0:
+        print "Changing ", filename
+        copyfile(filename, filestr)
+        makenewfile(filename, newstr)
+        sys.stdout.flush()
+
+def fromargs(args):
+    filename = args[1]
+    convertfile(filename)
+
+def convertall(direc=os.path.curdir, import_change=False):
+    """Convert all .py files in the directory given
+
+    For each file, a backup of <usesnumeric>.py is made as
+    <usesnumeric>.py.orig.  A new file named <usesnumeric>.py
+    is then written with the updated code.
+    """
+    files = glob.glob(os.path.join(direc,'*.py'))
+    for afile in files:
+        convertfile(afile, import_change)
+
+def _func(arg, dirname, fnames):
+    convertall(dirname, import_change=arg)
+
+def converttree(direc=os.path.curdir, import_change=False):
+    """Convert all .py files in the tree given
+
+    """
+    os.path.walk(direc, _func, import_change)
+
+if __name__ == '__main__':
+    fromargs(sys.argv)



More information about the Numpy-svn mailing list