[Scipy-svn] r5059 - in trunk/scipy/io/matlab: . tests

scipy-svn@scip... scipy-svn@scip...
Tue Nov 11 03:03:29 CST 2008


Author: matthew.brett@gmail.com
Date: 2008-11-11 03:03:27 -0600 (Tue, 11 Nov 2008)
New Revision: 5059

Modified:
   trunk/scipy/io/matlab/mio5.py
   trunk/scipy/io/matlab/tests/test_mio.py
Log:
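Further checks on input types for writing, to fix an infinite-recursion error when writing unsupported types such as a plain dictionary (ticket #653).  Moved MATLAB function and object types into their own ndarray subclasses (MatlabFunction, MatlabObject).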

Modified: trunk/scipy/io/matlab/mio5.py
===================================================================
--- trunk/scipy/io/matlab/mio5.py	2008-11-11 07:32:57 UTC (rev 5058)
+++ trunk/scipy/io/matlab/mio5.py	2008-11-11 09:03:27 UTC (rev 5059)
@@ -173,91 +173,28 @@
     pass
 
 
-class MatlabObject(object):
-    ''' Class to contain data read from matlab objects 
+class MatlabFunction(np.ndarray):
+    ''' class to signal this is a matlab function '''
+    def __new__(cls, input_array):
+        return np.asarray(input_array).view(cls)
 
-    Contains classname, and record array for field names and values
 
-    Attribute access fetches and sets record array fields if present
+class MatlabObject(np.ndarray):
+    def __new__(cls, input_array, classname=None):
+        # Input array is an already formed ndarray instance
+        # We first cast to be our class type
+        obj = np.asarray(input_array).view(cls)
+        # add the new attribute to the created instance
+        obj.classname = classname
+        # Finally, we must return the newly created object:
+        return obj
 
-    '''
-    def __init__(self, classname, fields):
-        """ Initialize MatlabObject
+    def __array_finalize__(self,obj):
+        # reset the attribute from passed original object
+        self.classname = getattr(obj, 'classname', None)
+        # We do not need to return anything
 
-        Parameters
-        ----------
-	self : object
-	classname : string
-	    class name for matlab object
-	fields : {recarray, string list}
-            either a recarray or a list of field names
 
-        >>> import numpy as np
-        >>> arr = np.zeros((1,1),dtype=[('field1','i2'),('field2','i2')])
-        >>> obj = MatlabObject('myclass', arr)
-        >>> obj = MatlabObject('myclass', ['field1', 'field2'])
-
-        """
-        # Initialize to make field setting work with __setattr__
-        self.__dict__['_fields'] = []
-        self.classname = classname
-        try: # recarray
-            fdict = fields.dtype.fields
-        except AttributeError: # something else
-            fields = tuple(fields)
-        else: # recarray again
-            self._fields = fdict.keys()
-            self.mobj_recarray = fields
-            return
-        # something else again
-        self._fields = fields
-        dtype = [(field, object) for field in fields]
-        self.mobj_recarray = np.zeros((1,1), dtype)
-    
-    def __getattr__(self, name):
-        ''' get attributes from object
-
-        Get attribute if present, otherwise field from recarray
-        
-        >>> import numpy as np
-        >>> arr = np.zeros((1,1),dtype=[('field1','i2'),('field2','i2')])
-        >>> obj = MatlabObject('myclass', arr)
-        >>> obj.field1
-        array([[0]], dtype=int16)
-        >>> obj = MatlabObject('myclass', ['field1', 'field2'])
-        >>> obj.field1
-        array([[0]], dtype=object)
-        >>> obj.classname
-        'myclass'
-        '''
-        if name in self.__dict__:
-            return self.__dict__[name]
-        mobj_recarray = self.__dict__['mobj_recarray']
-        if name in self.__dict__['_fields']:
-            return mobj_recarray[name]
-        else:
-            raise AttributeError(
-                "no field named %s in MatlabObject" % name)
-
-    def __setattr__(self, name, value):
-        ''' set attributes in object
-
-        Set field value from recarray, if present, else attribute
-
-        >>> import numpy as np
-        >>> arr = np.zeros((1,1),dtype=[('field1','i2'),('field2','i2')])
-        >>> obj = MatlabObject('myclass', arr)
-        >>> obj.field1[0,0] = 1
-        >>> obj.strangename = 'test'
-        >>> obj.strangename
-        'test'
-        '''
-        if name in self._fields:
-            self.mobj_recarray[name] = value
-        else:
-            self.__dict__[name] = value
-
-
 class Mat5ArrayReader(MatArrayReader):
     ''' Class to get Mat5 arrays
 
@@ -265,7 +202,13 @@
     factory function
     '''
 
-    def __init__(self, mat_stream, dtypes, processor_func, codecs, class_dtypes, struct_as_record):
+    def __init__(self,
+                 mat_stream,
+                 dtypes,
+                 processor_func,
+                 codecs,
+                 class_dtypes,
+                 struct_as_record):
         super(Mat5ArrayReader, self).__init__(mat_stream,
                                               dtypes,
                                               processor_func)
@@ -508,6 +451,7 @@
     def get_item(self):
         return self.read_element()
 
+
 class Mat5StructMatrixGetter(Mat5MatrixGetter):
     def __init__(self, array_reader, header):
         super(Mat5StructMatrixGetter, self).__init__(array_reader, header)
@@ -535,34 +479,21 @@
                 for name in field_names:
                     item.__dict__[name] = self.read_element()
                 result[i] = item
-
         return result.reshape(tupdims).T
 
 
-class Mat5ObjectMatrixGetter(Mat5MatrixGetter):
-    def get_array(self):
+class Mat5ObjectMatrixGetter(Mat5StructMatrixGetter):
+    def get_raw_array(self):
         '''Matlab ojects are essentially structs, with an extra field, the classname.'''
         classname = self.read_element().tostring()
-        namelength = self.read_element()[0]
-        names = self.read_element()
-        field_names = [names[i:i+namelength].tostring().strip('\x00')
-                       for i in xrange(0,len(names),namelength)]
-        result = MatlabObject(classname, field_names)
+        result = super(Mat5ObjectMatrixGetter, self).get_raw_array()
+        return MatlabObject(result, classname)
 
-        for field_name in field_names:
-            result.__setattr__(field_name, self.read_element())
 
-        return result
-
-
-class MatlabFunctionMatrix:
-    ''' Opaque object representing an array of function handles. '''
-    def __init__(self, arr):
-        self.arr = arr
-
 class Mat5FunctionMatrixGetter(Mat5CellMatrixGetter):
-    def get_array(self):
-        return MatlabFunctionMatrix(self.get_raw_array())
+    def get_raw_array(self):
+        result = super(Mat5ObjectMatrixGetter, self).get_raw_array()
+        return MatlabFunction(result)
 
 
 class MatFile5Reader(MatFileReader):
@@ -688,10 +619,16 @@
     mat_tag = np.zeros((), mdtypes_template['tag_full'])
     mat_tag['mdtype'] = miMATRIX
 
-    def __init__(self, file_stream, arr, name, is_global=False):
+    def __init__(self,
+                 file_stream,
+                 arr,
+                 name,
+                 is_global=False,
+                 unicode_strings=False):
         super(Mat5MatrixWriter, self).__init__(file_stream, arr, name)
         self.is_global = is_global
-
+        self.unicode_strings = unicode_strings
+        
     def write_dtype(self, arr):
         self.file_stream.write(arr.tostring())
 
@@ -838,13 +775,7 @@
         self.update_matrix_tag()
 
 
-class Mat5CompositeWriter(Mat5MatrixWriter):
-    def __init__(self, file_stream, arr, name, is_global=False, unicode_strings=False):
-        super(Mat5CompositeWriter, self).__init__(file_stream, arr, name, is_global)
-        self.unicode_strings = unicode_strings
-
-
-class Mat5CellWriter(Mat5CompositeWriter):
+class Mat5CellWriter(Mat5MatrixWriter):
     def write(self):
         self.write_header(mclass=mxCELL_CLASS)
         # loop over data, column major
@@ -855,10 +786,8 @@
             MW.write()
         self.update_matrix_tag()
 
-class Mat5FunctionWriter(Mat5CompositeWriter):
-    def __init__(self, file_stream, arr, name, is_global=False, unicode_strings=False):
-        super(Mat5FunctionWriter, self).__init__(file_stream, arr.arr, name, is_global)
 
+class Mat5FunctionWriter(Mat5MatrixWriter):
     def write(self):
         self.write_header(mclass=mxFUNCTION_CLASS)
         # loop over data, column major
@@ -870,48 +799,35 @@
         self.update_matrix_tag()
 
 
-class Mat5StructWriter(Mat5CompositeWriter):
+class Mat5StructWriter(Mat5MatrixWriter):
     def write(self):
         self.write_header(mclass=mxSTRUCT_CLASS)
+        self.write_fields()
 
+    def write_fields(self):
         # write fieldnames
         fieldnames = [f[0] for f in self.arr.dtype.descr]
         self.write_element(np.array([32], dtype='i4'))
-        self.write_element(np.array(fieldnames, dtype='S32'), mdtype=miINT8)
-
+        self.write_element(np.array(fieldnames, dtype='S32'),
+                           mdtype=miINT8)
         A = np.atleast_2d(self.arr).flatten('F')
-        MWG = Mat5WriterGetter(self.file_stream, self.unicode_strings)
+        MWG = Mat5WriterGetter(self.file_stream,
+                               self.unicode_strings)
         for el in A:
             for f in fieldnames:
                 MW = MWG.matrix_writer_factory(el[f], '')
                 MW.write()
         self.update_matrix_tag()
 
-class Mat5ObjectWriter(Mat5CompositeWriter):
-    def __init__(self, file_stream, arr, name, is_global=False, unicode_strings=False):
-        super(Mat5ObjectWriter, self).__init__(file_stream, arr.__dict__['mobj_recarray'], name, is_global)
-        self.classname = arr.classname
 
+class Mat5ObjectWriter(Mat5StructWriter):
     def write(self):
         self.write_header(mclass=mxOBJECT_CLASS)
+        self.write_element(np.array(self.arr.classname, dtype='S'),
+                           mdtype=miINT8)
+        self.write_fields()
 
-        # write classnames
-        self.write_element(np.array(self.classname, dtype='S'), mdtype=miINT8)
 
-        # write fieldnames
-        fieldnames = [f[0] for f in self.arr.dtype.descr]
-        self.write_element(np.array([32], dtype='i4'))
-        self.write_element(np.array(fieldnames, dtype='S32'), mdtype=miINT8)
-
-        A = np.atleast_2d(self.arr).flatten('F')
-        MWG = Mat5WriterGetter(self.file_stream, self.unicode_strings)
-        for el in A:
-            for f in fieldnames:
-                MW = MWG.matrix_writer_factory(el[f], '')
-                MW.write()
-        self.update_matrix_tag()
-
-
 class Mat5WriterGetter(object):
     ''' Wraps stream and options, provides methods for getting Writer objects '''
     def __init__(self, stream, unicode_strings):
@@ -923,33 +839,46 @@
 
     def matrix_writer_factory(self, arr, name, is_global=False):
         ''' Factory function to return matrix writer given variable to write
-        stream      - file or file-like stream to write to
-        arr         - array to write
-        name        - name in matlab (TM) workspace
+
+        Parameters
+        ----------
+        arr : array-like
+            array-like object to create writer for
+        name : string
+            name as it will appear in matlab workspace
+        is_global : {False, True} optional
+            whether variable will be global on load into matlab
         '''
+        # First check if these are sparse
         if spsparse:
             if spsparse.issparse(arr):
                 return Mat5SparseWriter(self.stream, arr, name, is_global)
-
-        if isinstance(arr, MatlabFunctionMatrix):
-            return Mat5FunctionWriter(self.stream, arr, name, is_global, self.unicode_strings)
-        if isinstance(arr, MatlabObject):
-            return Mat5ObjectWriter(self.stream, arr, name, is_global, self.unicode_strings)
-
-        arr = np.array(arr)
-        if arr.dtype.hasobject:
-            if arr.dtype.fields == None:
-                return Mat5CellWriter(self.stream, arr, name, is_global, self.unicode_strings)
+        # Next try and convert to an array
+        narr = np.asanyarray(arr)
+        if narr.dtype.type in (np.object, np.object_) and \
+           narr.size == 1 and narr == arr:
+            # No interesting conversion possible
+            raise TypeError('Could not convert %s (type %s) to array'
+                            % (arr, type(arr)))
+        args = (self.stream, narr, name, is_global, self.unicode_strings)
+        if isinstance(narr, MatlabFunction):
+            return Mat5FunctionWriter(*args)
+        if isinstance(narr, MatlabObject):
+            return Mat5ObjectWriter(*args)
+        if narr.dtype.hasobject: # cell or struct array
+            if narr.dtype.fields == None:
+                return Mat5CellWriter(*args)
             else:
-                return Mat5StructWriter(self.stream, arr, name, is_global, self.unicode_strings)
-        if arr.dtype.kind in ('U', 'S'):
+                return Mat5StructWriter(*args)
+        if narr.dtype.kind in ('U', 'S'):
             if self.unicode_strings:
-                return Mat5UniCharWriter(self.stream, arr, name, is_global)
+                return Mat5UniCharWriter(*args)
             else:
-                return Mat5CharWriter(self.stream, arr, name, is_global)
+                return Mat5CharWriter(*args)
         else:
-            return Mat5NumericWriter(self.stream, arr, name, is_global)
+            return Mat5NumericWriter(*args)
 
+
 class MatFile5Writer(MatFileWriter):
     ''' Class for writing mat5 files '''
     def __init__(self, file_stream,

Modified: trunk/scipy/io/matlab/tests/test_mio.py
===================================================================
--- trunk/scipy/io/matlab/tests/test_mio.py	2008-11-11 07:32:57 UTC (rev 5058)
+++ trunk/scipy/io/matlab/tests/test_mio.py	2008-11-11 09:03:27 UTC (rev 5059)
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 ''' Nose test generators
 
+Need function load / save / roundtrip tests
+
 '''
 from os.path import join, dirname
 from glob import glob
@@ -9,6 +11,7 @@
 import warnings
 import shutil
 import gzip
+import copy
 
 from numpy.testing import \
      assert_array_almost_equal, \
@@ -151,15 +154,17 @@
     {'name': 'structarr',
      'expected': {'teststructarr': a}
      })
-MO = MatlabObject('inline',
+ODT = np.dtype([(n, object) for n in
                  ['expr', 'inputExpr', 'args',
-                  'isEmpty', 'numArgs', 'version'])
-MO.expr = u'x'
-MO.inputExpr = u' x = INLINE_INPUTS_{1};'
-MO.args = u'x'
-MO.isEmpty = mlarr(0)
-MO.numArgs = mlarr(1)
-MO.version = mlarr(1)
+                  'isEmpty', 'numArgs', 'version']])
+MO = MatlabObject(np.zeros((1,1), dtype=ODT), 'inline')
+m0 = MO[0,0]
+m0['expr'] = array([u'x'])
+m0['inputExpr'] = array([u' x = INLINE_INPUTS_{1};'])
+m0['args'] = array([u'x'])
+m0['isEmpty'] = mlarr(0)
+m0['numArgs'] = mlarr(1)
+m0['version'] = mlarr(1)
 case_table5.append(
     {'name': 'object',
      'expected': {'testobject': MO}
@@ -171,7 +176,8 @@
     {'name': 'unicode',
     'expected': {'testunicode': array([u_str])}
     })
-# These should also have matlab load equivalents, but I can't get to matlab at the moment
+# These should also have matlab load equivalents,
+# but I can't get to matlab at the moment
 case_table5_rt = case_table5[:]
 case_table5_rt.append(
     {'name': 'sparsefloat',
@@ -183,8 +189,10 @@
      'expected': {'testsparsecomplex':
                   SP.coo_matrix(array([[-1+2j,0,2],[0,-3j,0]]))},
      })
+case_table5_rt.append(
+    {'name': 'objectarray',
+     'expected': {'testobjectarray': np.repeat(MO, 2).reshape(1,2)}})
 
-
 def _check_level(label, expected, actual):
     """ Check one level of a potentially nested array """
     if SP.issparse(expected): # allow different types of sparse matrices
@@ -199,26 +207,13 @@
     typac = type(actual)
     assert typex is typac, \
            "Expected type %s, got %s at %s" % (typex, typac, label)
-    # object, as container for matlab objects
-    if isinstance(expected, MatlabObject):
-        ex_fields = dir(expected)
-        ac_fields = dir(actual)
-        for k in ex_fields:
-            if k.startswith('__') and k.endswith('__'):
-                continue
-            assert k in ac_fields, \
-                   "Missing expected property %s for %s" % (k, label)
-            ev = expected.__dict__[k]
-            v = actual.__dict__[k]
-            level_label = "%s, property %s, " % (label, k)
-            _check_level(level_label, ev, v)
-        return
     # A field in a record array may not be an ndarray
     # A scalar from a record array will be type np.void
-    if not isinstance(expected, (np.void, np.ndarray)): 
+    if not isinstance(expected,
+                      (np.void, np.ndarray, MatlabObject)): 
         assert_equal(expected, actual)
         return
-    # This is an ndarray
+    # This is an ndarray-like thing
     assert_true(expected.shape == actual.shape,
                 msg='Expected shape %s, got %s at %s' % (expected.shape,
                                                          actual.shape,
@@ -226,6 +221,8 @@
                 )
     ex_dtype = expected.dtype
     if ex_dtype.hasobject: # array of objects
+        if isinstance(expected, MatlabObject):
+            assert_equal(expected.classname, actual.classname)
         for i, ev in enumerate(expected):
             level_label = "%s, [%d], " % (label, i)
             _check_level(level_label, ev, actual[i])
@@ -311,7 +308,10 @@
         join(test_data_path, 'testhdf5*.mat'))
     assert len(filenames)
     for filename in filenames:
-        assert_raises(NotImplementedError, loadmat, filename, struct_as_record=True)
+        assert_raises(NotImplementedError,
+                      loadmat,
+                      filename,
+                      struct_as_record=True)
 
 
 def test_warnings():
@@ -326,20 +326,16 @@
     # This too
     yield assert_raises, FutureWarning, find_mat_file, fname
     # we need kwargs for this one
-    try:
-        mres = loadmat(fname, struct_as_record=False, basename='raw')
-    except DeprecationWarning:
-        pass
-    else:
-        assert False, 'Did not raise deprecation warning'
+    yield (lambda a, k: assert_raises(*a, **k), 
+          (DeprecationWarning, loadmat, fname), 
+          {'struct_as_record':True, 'basename':'raw'})
     # Test warning for default format change
     savemat(StringIO(), {}, False, '4')
     savemat(StringIO(), {}, False, '5')
     yield assert_raises, FutureWarning, savemat, StringIO(), {}
     warnings.resetwarnings()
 
-@dec.knownfailureif(True, "Infinite recursion when writing a simple "\
-                          "dictionary to matlab file.")
+
 def test_regression_653():
     """Regression test for #653."""
-    savemat(StringIO(), {'d':{1:2}}, format='5')
+    assert_raises(TypeError, savemat, StringIO(), {'d':{1:2}}, format='5')



More information about the Scipy-svn mailing list