[Scipy-svn] r4490 - trunk/scipy/sandbox/mkufunc

scipy-svn@scip... scipy-svn@scip...
Sat Jun 28 09:36:38 CDT 2008


Author: ilan
Date: 2008-06-28 09:36:29 -0500 (Sat, 28 Jun 2008)
New Revision: 4490

Modified:
   trunk/scipy/sandbox/mkufunc/mkufunc.py
   trunk/scipy/sandbox/mkufunc/test_mkufunc.py
Log:
Improved dispatch on type in mkufunc and added more tests

Modified: trunk/scipy/sandbox/mkufunc/mkufunc.py
===================================================================
--- trunk/scipy/sandbox/mkufunc/mkufunc.py	2008-06-28 13:54:11 UTC (rev 4489)
+++ trunk/scipy/sandbox/mkufunc/mkufunc.py	2008-06-28 14:36:29 UTC (rev 4490)
@@ -15,7 +15,6 @@
 
 
 verbose = False
-_cnt = 0
 
 typedict = {
     int:    ['NPY_LONG',   'long'  ],
@@ -41,7 +40,7 @@
     [<type 'int'>, <type 'int'>]
 
     Attributes:
-
+        f           -- the Python function object
         n           -- id number
         sig         -- signature
         nin         -- number of input arguments
@@ -56,10 +55,9 @@
                     -- generate the C support code to make this
                        function part work with PyUFuncGenericFunction
     """
-    def __init__(self, f, signature):
-        global _cnt
-        _cnt += 1
-        self.n = _cnt
+    def __init__(self, f, signature, n=0):
+        self.f = f
+        self.n = n
         self.sig = signature
         self.nin = f.func_code.co_argcount     # input args
         self.nout = len(self.sig) - self.nin
@@ -83,7 +81,7 @@
         self._prefix = 'f%i_' % self.n
         self._allCsrc = src.replace('pypy_', self._prefix + 'pypy_')
         self.cname = self._prefix + 'pypy_g_' + f.__name__
-        
+
     def cfunc(self):
         p = re.compile(r'^\w+[*\s\w]+' + self.cname +
                        r'\s*\([^)]*\)\s*\{.*?[\n\r]\}[\n\r]',
@@ -110,7 +108,7 @@
         n = self.n
         cname = self.cname
         return '''
-static %(rettype)s foo_%(n)i(%(arg0type)s x)
+static %(rettype)s wrap_%(cname)s(%(arg0type)s x)
 {
 	return %(cname)s(x);
 }
@@ -120,8 +118,6 @@
 static void
 PyUFunc_%(n)i(char **args, npy_intp *dimensions, npy_intp *steps, void *func)
 {
-	/* printf("PyUFunc_%(n)i\\n"); */
-	
 	npy_intp n = dimensions[0];
 	npy_intp is0 = steps[0];
 	npy_intp os = steps[1];
@@ -179,7 +175,7 @@
     """
     signatures.sort(key=lambda sig: [numpy.dtype(typ).num for typ in sig])
     
-    cfuncs = [Cfunc(f, sig) for sig in signatures]
+    cfuncs = [Cfunc(f, sig, n) for n, sig in enumerate(signatures)]
     
     write_pypyc(cfuncs)
     
@@ -189,11 +185,13 @@
 
     pyufuncs = ''.join('\tPyUFunc_%i,\n' % cf.n for cf in cfuncs)
     
-    data = ''.join('\t(void *) foo_%i,\n' % cf.n for cf in cfuncs)
+    data = ''.join('\t(void *) wrap_%s,\n' % cf.cname for cf in cfuncs)
+    
+    types = ''.join('\t%s  /* %i */\n' %
+                    (''.join(typedict[t][0] + ', ' for t in cf.sig), cf.n)
+                    for cf in cfuncs)
 
-    foo_signatures = ''.join('\t%s  /* %i */\n' %
-                         (''.join(typedict[t][0] + ', ' for t in cf.sig), cf.n)
-                         for cf in cfuncs)
+    fname = f.__name__
     
     support_code = '''
 extern "C" {
@@ -201,14 +199,14 @@
 
 %(func_support)s
 
-static PyUFuncGenericFunction foo_functions[] = {
+static PyUFuncGenericFunction %(fname)s_functions[] = {
 %(pyufuncs)s};
 
-static void *foo_data[] = {
+static void *%(fname)s_data[] = {
 %(data)s};
 
-static char foo_signatures[] = {
-%(foo_signatures)s};
+static char %(fname)s_types[] = {
+%(types)s};
 ''' % locals()
 
     ntypes = len(signatures)
@@ -218,14 +216,14 @@
 import_ufunc();
 
 return_val = PyUFunc_FromFuncAndData(
-    foo_functions,
-    foo_data,
-    foo_signatures,
+    %(fname)s_functions,
+    %(fname)s_data,
+    %(fname)s_types,
     %(ntypes)i,         /* ntypes */
     %(nin)i,            /* nin */
     1,                  /* nout */
     PyUFunc_None,       /* identity */
-    "foo",              /* name */
+    "%(fname)s",        /* name */
     "",                 /* doc */
     0);
 ''' % locals()
@@ -272,20 +270,40 @@
     print "y =", y, y.dtype
 
 
-def mkufunc(arg0):
+def mkufunc(arg0=[float]):
     """ The actual API function, to be used as decorator function.
         
     """
     class Compile(object):
         
         def __init__(self, f):
+            nin = f.func_code.co_argcount
+            nout = 1
+            for i, sig in enumerate(signatures):
+                if sig in typedict.keys():
+                    signatures[i] = (nin + nout) * (sig,)
+                elif isinstance(sig, tuple):
+                    pass
+                else:
+                    raise TypeError
+
+            for sig in signatures:
+                assert isinstance(sig, tuple)
+                if len(sig) != nin + nout:
+                    raise TypeError("signature %r does not match the "
+                                    "number of args of function %s" %
+                                    (sig, f.__name__))
+                for t in sig:
+                    if t not in typedict.keys():
+                        raise TypeError
+            
             print 'sigs:', signatures
             self.ufunc = genufunc(f, signatures)
             #self.ufunc = f
-
+            
         def __call__(self, *args):
             return self.ufunc(*args)
-
+        
     if isinstance(arg0, FunctionType):
         f = arg0
         signatures = [float]
@@ -294,11 +312,11 @@
     elif isinstance(arg0, ListType):
         signatures = arg0
         return Compile
-
-    elif arg0 in typedict:
+    
+    elif arg0 in typedict.keys():
         signatures = [arg0]
         return Compile
-
+    
     else:
         raise TypeError("first argument has to be a function, a type, or "
                         "a list of signatures")
@@ -307,13 +325,18 @@
 if __name__ == '__main__':
     import doctest
     #doctest.testmod()
+    # Debug scaffolding: only test2() runs; the sqr demo below is unreachable.
+    test2()

+    raise SystemExit  # same effect as exit(), without relying on the site module
+    
+
     def sqr(x):
         return x * x
-
+    
     #sqr = mkufunc({})(sqr)
     sqr = mkufunc([(float, float)])(sqr)
     #sqr = mkufunc(int)(sqr)
     #sqr = mkufunc(sqr)
-
+    
     print sqr(8)

Modified: trunk/scipy/sandbox/mkufunc/test_mkufunc.py
===================================================================
--- trunk/scipy/sandbox/mkufunc/test_mkufunc.py	2008-06-28 13:54:11 UTC (rev 4489)
+++ trunk/scipy/sandbox/mkufunc/test_mkufunc.py	2008-06-28 14:36:29 UTC (rev 4490)
@@ -1,22 +1,85 @@
 import math
-from math import sin, cos, pi
 import unittest
 
-from numpy import array, arange, allclose
+from numpy import array, arange, allclose, vectorize
 
 from mkufunc import Cfunc, genufunc, mkufunc
 
+def f(x):  # sample scalar function; each test below defines its own local f
+    return 3.2 * x * x - 18.3 * x + math.sin(x)
 
+class Arg_Tests(unittest.TestCase):
+    # Exercise the accepted forms of the mkufunc decorator argument.
+    def check_ufunc(self, f):
+        # f is expected to square its input (arrays, lists, tuples, scalars).
+        for arg in (array([0.0, 1.0, 2.5]),
+                    [0.0, 1.0, 2.5],
+                    (0.0, 1.0, 2.5)):
+            self.assert_(allclose(f(arg), [0.0, 1.0, 6.25]))
+            
+        self.assertEqual(f(3), 9)
+        self.assert_(abs(f(-2.5) - 6.25) < 1E-10)
+
+    def test_direct(self):
+        @mkufunc
+        def f(x):
+            return x * x
+        self.check_ufunc(f)
+        
+    def test_noargs(self):
+        @mkufunc()
+        def f(x):
+            return x * x
+        self.check_ufunc(f)
+        
+    def test_varargs(self):
+        for arg in (float,
+                    [float],
+                    [(float, float)]):
+            @mkufunc(arg)
+            def f(x):
+                return x * x
+            self.check_ufunc(f)
+
+
 class Math_Tests(unittest.TestCase):
     
-    def test_sin(self):
-        @mkufunc([(float, float)])
-        def u_sin(x):
-            return sin(x)
+    def test_func1arg(self):  # one-argument math functions, scalar and array input
+        for f in (math.exp, math.log, math.sqrt,
+                  math.acos, math.asin, math.atan,
+                  math.cos, math.sin, math.tan):
+            @mkufunc
+            def uf(x):
+                return f(x)  # NOTE(review): relies on mkufunc compiling uf now, while f is the current loop value
+            x = 0.4376
+            a = uf(x)
+            b = f(x)
+            self.assert_(abs(a - b) < 1E-10, '%r %s != %s' % (f, a, b))
+            xx = arange(0.1, 0.9, 0.01)  # all values lie inside the domains of log/sqrt/acos/asin
+            a = uf(xx)
+            b = [f(x) for x in xx]
+            self.assert_(allclose(a, b))

-        x = 1.23
-        self.assert_(u_sin(x), sin(x))
+    def test_arithmetic(self):  # compiled ufunc vs. plain numpy evaluation of a rational function
+        def f(x):
+            return (4 * x + 2) / (x * x - 7 * x + 1)
+        uf = mkufunc(f)
+        x = arange(0, 2, 0.1)
+        self.assert_(allclose(uf(x), f(x)))
+    
 
+class Loop_Tests(unittest.TestCase):
+    pass  # placeholder: no tests written yet
 
+class Switch_Tests(unittest.TestCase):
+    pass  # placeholder: no tests written yet
+
+class FreeVariable_Tests(unittest.TestCase):
+    pass  # placeholder: no tests written yet
+
+class Misc_Tests(unittest.TestCase):
+    pass  # placeholder: no tests written yet
+
+
 if __name__ == '__main__':
     unittest.main()



More information about the Scipy-svn mailing list