[Numpy-svn] r2715 - trunk/numpy/f2py/lib

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Jun 30 16:45:23 CDT 2006


Author: pearu
Date: 2006-06-30 16:45:16 -0500 (Fri, 30 Jun 2006)
New Revision: 2715

Modified:
   trunk/numpy/f2py/lib/statements.py
   trunk/numpy/f2py/lib/test_parser.py
Log:
More unit-tests for Fortran parser.

Modified: trunk/numpy/f2py/lib/statements.py
===================================================================
--- trunk/numpy/f2py/lib/statements.py	2006-06-30 20:45:55 UTC (rev 2714)
+++ trunk/numpy/f2py/lib/statements.py	2006-06-30 21:45:16 UTC (rev 2715)
@@ -8,11 +8,17 @@
 
 is_name = re.compile(r'\w+\Z').match
 
-def split_comma(line, item):
+def split_comma(line, item = None, comma=','):
+    items = []
+    if item is None:
+        for s in line.split(comma):
+            s = s.strip()
+            if not s: continue
+            items.append(s)
+        return items
     newitem = item.copy(line, True)
     apply_map = newitem.apply_map
-    items = []
-    for s in newitem.get_line().split(','):
+    for s in newitem.get_line().split(comma):
         s = apply_map(s).strip()
         if not s: continue
         items.append(s)
@@ -37,14 +43,20 @@
     """
     def process_item(self):
         assert not self.item.has_map()
-        clsname = self.__class__.__name__.lower()
+        if hasattr(self,'stmtname'):
+            clsname = self.stmtname
+        else:
+            clsname = self.__class__.__name__.lower()
         line = self.item.get_line()[len(clsname):].lstrip()
         if line.startswith('::'):
             line = line[2:].lstrip()
         self.items = [s.strip() for s in line.split(',')]
         return
     def __str__(self):
-        clsname = self.__class__.__name__.upper()
+        if hasattr(self,'stmtname'):
+            clsname = self.stmtname.upper()
+        else:
+            clsname = self.__class__.__name__.upper()
         s = ', '.join(self.items)
         if s:
             s = ' ' + s
@@ -1174,28 +1186,68 @@
     def process_item(self):
         line = self.item.get_line()[6:].lstrip()
         i = line.index(')')
-        self.specs = line[1:i].strip()
+
+        line0 = line[1:i]
         line = line[i+1:].lstrip()
-        stmt = GeneralAssignment(self, self.item.copy(line))
+        stmt = GeneralAssignment(self, self.item.copy(line, True))
         if stmt.isvalid:
             self.content = [stmt]
         else:
             self.isvalid = False
+            return
+
+        specs = []
+        mask = ''
+        for l in split_comma(line0,self.item):
+            j = l.find('=')
+            if j==-1:
+                assert not mask,`mask,l`
+                mask = l
+                continue
+            assert j!=-1,`l`
+            index = l[:j].rstrip()
+            it = self.item.copy(l[j+1:].lstrip())
+            l = it.get_line()
+            k = l.split(':')
+            if len(k)==3:
+                s1, s2, s3 = map(it.apply_map,
+                                 [k[0].strip(),k[1].strip(),k[2].strip()])
+            else:
+                assert len(k)==2,`k`
+                s1, s2 = map(it.apply_map,
+                             [k[0].strip(),k[1].strip()])
+                s3 = '1'
+            specs.append((index,s1,s2,s3))
+
+        self.specs = specs
+        self.mask = mask
         return
+
     def __str__(self):
         tab = self.get_indent_tab()
-        return tab + 'FORALL (%s) %s' % (self.specs, str(self.content[0]).lstrip())
+        l = []
+        for index,s1,s2,s3 in self.specs:
+            s = '%s = %s : %s' % (index,s1,s2)
+            if s3!='1':
+                s += ' : %s' % (s3)
+            l.append(s)
+        s = ', '.join(l)
+        if self.mask:
+            s += ', ' + self.mask
+        return tab + 'FORALL (%s) %s' % \
+               (s, str(self.content[0]).lstrip())
 
 ForallStmt = Forall
 
 class SpecificBinding(Statement):
     """
-    PROCEDURE [ (<interface-name>) ]  [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ]
+    PROCEDURE [ ( <interface-name> ) ]  [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ]
     <binding-attr> = PASS [ ( <arg-name> ) ]
                    | NOPASS
                    | NON_OVERRIDABLE
                    | DEFERRED
                    | <access-spec>
+    <access-spec> = PUBLIC | PRIVATE
     """
     match = re.compile(r'procedure\b',re.I).match
     def process_item(self):
@@ -1206,46 +1258,69 @@
             line = line[i+1:].lstrip()
         else:
             name = ''
-        self.interface_name = name
+        self.iname = name
         if line.startswith(','):
             line = line[1:].lstrip()
         i = line.find('::')
         if i != -1:
-            attrs = line[:i].rstrip()
+            attrs = split_comma(line[:i], self.item)
             line = line[i+2:].lstrip()
         else:
-            attrs = ''
-        self.attrs = attrs
-        self.rest = line
+            attrs = []
+        attrs1 = []
+        for attr in attrs:
+            if is_name(attr):
+                attr = attr.upper()
+            else:
+                i = attr.find('(')
+                assert i!=-1 and attr.endswith(')'),`attr`
+                attr = '%s (%s)' % (attr[:i].rstrip().upper(), attr[i+1:-1].strip())
+            attrs1.append(attr)
+        self.attrs = attrs1
+        i = line.find('=')
+        if i==-1:
+            self.name = line
+            self.bname = ''
+        else:
+            self.name = line[:i].rstrip()
+            self.bname = line[i+1:].lstrip()[1:].lstrip()
         return
     def __str__(self):
         tab = self.get_indent_tab()
         s = 'PROCEDURE '
-        if self.interface_name:
-            s += ' (' + self.interface_name + ')'
+        if self.iname:
+            s += '(' + self.iname + ') '
         if self.attrs:
-            s += ' , ' + self.attrs + ' :: '
-        return tab + s + rest
+            s += ', ' + ', '.join(self.attrs) + ' :: '
+        if self.bname:
+            s += '%s => %s' % (self.name, self.bname)
+        else:
+            s += self.name
+        return tab + s
 
 class GenericBinding(Statement):
     """
     GENERIC [ , <access-spec> ] :: <generic-spec> => <binding-name-list>
     """
-    match = re.compile(r'generic\b.*::.*=.*\Z', re.I).match
+    match = re.compile(r'generic\b.*::.*=\>.*\Z', re.I).match
     def process_item(self):
         line = self.item.get_line()[7:].lstrip()
         if line.startswith(','):
             line = line[1:].lstrip()
         i = line.index('::')
-        self.specs = line[:i].lstrip()
-        self.rest = line[i+2:].lstrip()
+        self.aspec = line[:i].rstrip().upper()
+        line = line[i+2:].lstrip()
+        i = line.index('=>')
+        self.spec = self.item.apply_map(line[:i].rstrip())
+        self.items = split_comma(line[i+2:].lstrip())
         return
+
     def __str__(self):
         tab = self.get_indent_tab()
         s = 'GENERIC'
-        if self.specs:
-            s += ', '+self.specs
-        s += ' :: ' + self.rest
+        if self.aspec:
+            s += ', '+self.aspec
+        s += ' :: ' + self.spec + ' => ' + ', '.join(self.items)
         return tab + s
 
         
@@ -1253,6 +1328,7 @@
     """
     FINAL [ :: ] <final-subroutine-name-list>
     """
+    stmtname = 'final'
     match = re.compile(r'final\b', re.I).match
 
 class Allocatable(Statement):
@@ -1264,14 +1340,10 @@
         line = self.item.get_line()[11:].lstrip()
         if line.startswith('::'):
             line = line[2:].lstrip()
-        items = []
-        for s in line.split(','):
-            s = s.strip()
-            items.append(s)
-        self.items = items
+        self.items = split_comma(line, self.item)
         return
     def __str__(self):
-        return self.get_tab_indent() + 'ALLOCATABLE ' + ', '.join(self.items) 
+        return self.get_indent_tab() + 'ALLOCATABLE ' + ', '.join(self.items) 
 
 class Asynchronous(StatementWithNamelist):
     """
@@ -1285,12 +1357,30 @@
     <language-binding-spec> = BIND ( C [ , NAME = <scalar-char-initialization-expr> ] )
     <bind-entity> = <entity-name> | / <common-block-name> /
     """
-    match = re.compile(r'bind\s*\(.*\)\Z',re.I).match
+    match = re.compile(r'bind\s*\(.*\)',re.I).match
     def process_item(self):
-        self.value = self.item.get_line()[4].lstrip()[1:-1].strip()
+        line = self.item.get_line()[4:].lstrip()
+        specs = []
+        for spec in split_comma(line[1:line.index(')')].strip(), self.item):
+            kv = spec.split('=', 1)  # keyword spec, e.g. NAME = "..."
+            if len(kv)==2: spec = '%s = %s' % (kv[0].strip().upper(), kv[1].strip())
+            elif is_name(spec): spec = spec.upper()
+            specs.append(spec)
+        self.specs = specs
+        line = line[line.index(')')+1:].lstrip()
+        if line.startswith('::'):
+            line = line[2:].lstrip()
+        items = []
+        for item in split_comma(line, self.item):
+            if item.startswith('/'):
+                assert item.endswith('/'),`item`
+                item = '/ ' + item[1:-1].strip() + ' /'
+            items.append(item)
+        self.items = items
         return
     def __str__(self):
-        return self.get_indent_tab() + 'BIND (%s)' % (self.value)
+        return self.get_indent_tab() + 'BIND (%s) %s' %\
+               (', '.join(self.specs), ', '.join(self.items))
 
 # IF construct statements
 
@@ -1303,22 +1393,25 @@
     def process_item(self):
         item = self.item
         self.name = item.get_line()[4:].strip()
-        if self.name and not self.name==self.parent.name:
+        parent_name = getattr(self.parent,'name','')
+        if self.name and self.name!=parent_name:
             message = self.reader.format_message(\
                         'WARNING',
                         'expected if-construct-name %r but got %r, skipping.'\
-                        % (self.parent.name, self.name),
+                        % (parent_name, self.name),
                         item.span[0],item.span[1])
-            print >> sys.stderr, message
+            self.show_message(message)
             self.isvalid = False        
         return
 
     def __str__(self):
-        return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name
+        if self.name:
+            return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name
+        return self.get_indent_tab(deindent=True) + 'ELSE'
 
 class ElseIf(Statement):
     """
-    ELSE IF ( <scalar-logical-expr> ) THEN [<if-construct-name>]
+    ELSE IF ( <scalar-logical-expr> ) THEN [ <if-construct-name> ]
     """
     match = re.compile(r'else\s*if\s*\(.*\)\s*then\s*\w*\s*\Z',re.I).match
 
@@ -1327,21 +1420,25 @@
         line = item.get_line()[4:].lstrip()[2:].lstrip()
         i = line.find(')')
         assert line[0]=='('
-        self.expr = line[1:i]
+        self.expr = item.apply_map(line[1:i])
         self.name = line[i+1:].lstrip()[4:].strip()
-        if self.name and not self.name==self.parent.name:
+        parent_name = getattr(self.parent,'name','')
+        if self.name and self.name!=parent_name:
             message = self.reader.format_message(\
                         'WARNING',
                         'expected if-construct-name %r but got %r, skipping.'\
-                        % (self.parent.name, self.name),
+                        % (parent_name, self.name),
                         item.span[0],item.span[1])
-            print >> sys.stderr, message
+            self.show_message(message)
             self.isvalid = False        
         return
         
     def __str__(self):
-        return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN %s' \
-               % (self.expr, self.name)
+        s = ''
+        if self.name:
+            s = ' ' + self.name
+        return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN%s' \
+               % (self.expr, s)
 
 # SelectCase construct statements
 
@@ -1357,27 +1454,49 @@
     """
     match = re.compile(r'case\b\s*(\(.*\)|DEFAULT)\s*\w*\Z',re.I).match
     def process_item(self):
-        assert self.parent.__class__.__name__=='Select',`self.parent.__class__`
+        #assert self.parent.__class__.__name__=='Select',`self.parent.__class__`
         line = self.item.get_line()[4:].lstrip()
         if line.startswith('('):
             i = line.find(')')
-            self.ranges = line[1:i].strip()
+            items = split_comma(line[1:i].strip(), self.item)
             line = line[i+1:].lstrip()
         else:
             assert line.startswith('default'),`line`
-            self.ranges = ''
+            items = []
             line = line[7:].lstrip()
+        for i in range(len(items)):
+            it = self.item.copy(items[i])
+            rl = []
+            for r in it.get_line().split(':'):
+                rl.append(it.apply_map(r.strip()))
+            items[i] = rl
+        self.items = items
         self.name = line
-        if self.name and not self.name==self.parent.name:
+        parent_name = getattr(self.parent, 'name', '')
+        if self.name and self.name!=parent_name:
             message = self.reader.format_message(\
                         'WARNING',
                         'expected case-construct-name %r but got %r, skipping.'\
-                        % (self.parent.name, self.name),
+                        % (parent_name, self.name),
                         self.item.span[0],self.item.span[1])
-            print >> sys.stderr, message
+            self.show_message(message)
             self.isvalid = False        
         return
 
+    def __str__(self):
+        tab = self.get_indent_tab()
+        s = 'CASE'
+        if self.items:
+            l = []
+            for item in self.items:
+                l.append((' : '.join(item)).strip())
+            s += ' ( %s )' % (', '.join(l))
+        else:
+            s += ' DEFAULT'
+        if self.name:
+            s += ' ' + self.name
+        return tab + s  # indent like the other statement __str__ methods
+
 # Where construct statements
 
 class Where(Statement):

Modified: trunk/numpy/f2py/lib/test_parser.py
===================================================================
--- trunk/numpy/f2py/lib/test_parser.py	2006-06-30 20:45:55 UTC (rev 2714)
+++ trunk/numpy/f2py/lib/test_parser.py	2006-06-30 21:45:16 UTC (rev 2715)
@@ -3,18 +3,15 @@
 from block_statements import *
 from readfortran import Line, FortranStringReader
 
-def toLine(line, label=''):
+
+def parse(cls, line, label=''):
     if label:
         line = label + ' : ' + line
     reader = FortranStringReader(line, True, False)
-    return reader.next()
-
-def parse(cls, line, label=''):
-    item = toLine(line, label=label)
+    item = reader.next()
     if not cls.match(item.get_line()):
         raise ValueError, '%r does not match %s pattern' % (line, cls.__name__)
     stmt = cls(item, item)
-
     if stmt.isvalid:
         return str(stmt)
     raise ValueError, 'parsing %r with %s pattern failed' % (line, cls.__name__)
@@ -306,5 +303,74 @@
         assert_equal(parse(Import,'import::a'),'IMPORT a')
         assert_equal(parse(Import,'import a , b'),'IMPORT a, b')
 
+    def check_forall(self):
+        assert_equal(parse(ForallStmt,'forall (i = 1:n(k,:) : 2) a(i) = i*i*b(i)'),
+                     'FORALL (i = 1 : n(k,:) : 2) a(i) = i*i*b(i)')
+        assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3) a(i) = b(i,i)'),
+                     'FORALL (i = 1 : n, j = 2 : 3) a(i) = b(i,i)')
+        assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3, 1+a(1,2)) a(i) = b(i,i)'),
+                     'FORALL (i = 1 : n, j = 2 : 3, 1+a(1,2)) a(i) = b(i,i)')
+
+    def check_specificbinding(self):
+        assert_equal(parse(SpecificBinding,'procedure a'),'PROCEDURE a')
+        assert_equal(parse(SpecificBinding,'procedure :: a'),'PROCEDURE a')
+        assert_equal(parse(SpecificBinding,'procedure , NOPASS :: a'),'PROCEDURE , NOPASS :: a')
+        assert_equal(parse(SpecificBinding,'procedure , public, pass(x ) :: a'),'PROCEDURE , PUBLIC, PASS (x) :: a')
+        assert_equal(parse(SpecificBinding,'procedure(n) a'),'PROCEDURE (n) a')
+        assert_equal(parse(SpecificBinding,'procedure(n),pass :: a'),
+                     'PROCEDURE (n) , PASS :: a')
+        assert_equal(parse(SpecificBinding,'procedure(n) :: a'),
+                     'PROCEDURE (n) a')
+        assert_equal(parse(SpecificBinding,'procedure a= >b'),'PROCEDURE a => b')
+        assert_equal(parse(SpecificBinding,'procedure(n),pass :: a =>c'),
+                     'PROCEDURE (n) , PASS :: a => c')
+
+    def check_genericbinding(self):
+        assert_equal(parse(GenericBinding,'generic :: a=>b'),'GENERIC :: a => b')
+        assert_equal(parse(GenericBinding,'generic, public :: a=>b'),'GENERIC, PUBLIC :: a => b')
+        assert_equal(parse(GenericBinding,'generic, public :: a(1,2)=>b ,c'),
+                     'GENERIC, PUBLIC :: a(1,2) => b, c')
+
+    def check_finalbinding(self):
+        assert_equal(parse(FinalBinding,'final a'),'FINAL a')
+        assert_equal(parse(FinalBinding,'final::a'),'FINAL a')
+        assert_equal(parse(FinalBinding,'final a , b'),'FINAL a, b')
+
+    def check_allocatable(self):
+        assert_equal(parse(Allocatable,'allocatable a'),'ALLOCATABLE a')
+        assert_equal(parse(Allocatable,'allocatable :: a'),'ALLOCATABLE a')
+        assert_equal(parse(Allocatable,'allocatable a (1,2)'),'ALLOCATABLE a (1,2)')
+        assert_equal(parse(Allocatable,'allocatable a (1,2) ,b'),'ALLOCATABLE a (1,2), b')
+
+    def check_asynchronous(self):
+        assert_equal(parse(Asynchronous,'asynchronous a'),'ASYNCHRONOUS a')
+        assert_equal(parse(Asynchronous,'asynchronous::a'),'ASYNCHRONOUS a')
+        assert_equal(parse(Asynchronous,'asynchronous a , b'),'ASYNCHRONOUS a, b')
+
+    def check_bind(self):
+        assert_equal(parse(Bind,'bind(c) a'),'BIND (C) a')
+        assert_equal(parse(Bind,'bind(c) :: a'),'BIND (C) a')
+        assert_equal(parse(Bind,'bind(c) a ,b'),'BIND (C) a, b')
+        assert_equal(parse(Bind,'bind(c) /a/'),'BIND (C) / a /')
+        assert_equal(parse(Bind,'bind(c) /a/ ,b'),'BIND (C) / a /, b')
+        assert_equal(parse(Bind,'bind(c,name="hey") a'),'BIND (C, NAME = "hey") a')
+
+    def check_else(self):
+        assert_equal(parse(Else,'else'),'ELSE')
+        assert_equal(parse(ElseIf,'else if (a) then'),'ELSE IF (a) THEN')
+        assert_equal(parse(ElseIf,'else if (a.eq.b(1,2)) then'),
+                     'ELSE IF (a.eq.b(1,2)) THEN')
+
+    def check_case(self):
+        assert_equal(parse(Case,'case (1)'),'CASE ( 1 )')
+        assert_equal(parse(Case,'case (1:)'),'CASE ( 1 : )')
+        assert_equal(parse(Case,'case (:1)'),'CASE ( : 1 )')
+        assert_equal(parse(Case,'case (1:2)'),'CASE ( 1 : 2 )')
+        assert_equal(parse(Case,'case (a(1,2))'),'CASE ( a(1,2) )')
+        assert_equal(parse(Case,'case ("ab")'),'CASE ( "ab" )')
+        assert_equal(parse(Case,'case default'),'CASE DEFAULT')
+        assert_equal(parse(Case,'case (1:2 ,3:4)'),'CASE ( 1 : 2, 3 : 4 )')
+        assert_equal(parse(Case,'case (a(1,:):)'),'CASE ( a(1,:) : )')
+    
 if __name__ == "__main__":
     NumpyTest().run()



More information about the Numpy-svn mailing list