[Scipy-svn] r2157 - in trunk/Lib/io: . tests tests/data

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Aug 12 16:37:55 CDT 2006


Author: stefan
Date: 2006-08-12 16:37:25 -0500 (Sat, 12 Aug 2006)
New Revision: 2157

Added:
   trunk/Lib/io/tests/data/
   trunk/Lib/io/tests/data/test3dmatrix_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testcell_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testcellnest_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testcomplex_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testdouble_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testmatrix_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testminus_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testobject_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testonechar_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testsparse_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/testsparsecomplex_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/teststring_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/teststringarray_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/teststruct_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/teststructarr_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/data/teststructnest_6.5.1_GLNX86.mat
   trunk/Lib/io/tests/test_mio.py
Modified:
   trunk/Lib/io/mio.py
Log:
Improved MATLAB file reading [for Matthew Brett and Nick Fotopoulos].


Modified: trunk/Lib/io/mio.py
===================================================================
--- trunk/Lib/io/mio.py	2006-08-09 22:43:16 UTC (rev 2156)
+++ trunk/Lib/io/mio.py	2006-08-12 21:37:25 UTC (rev 2157)
@@ -2,22 +2,21 @@
 
 # Author: Travis Oliphant
 
-from numpy import squeeze
-from numpy import ndarray
-from numpy import *
-import numpyio
 import struct, os, sys
 import types
+from tempfile import mkstemp
+import zlib
+
+from numpy import array, asarray, empty, obj2sctype, product, reshape, \
+    squeeze, transpose, zeros, vstack, ndarray, shape, diff, where, uint8
+import numpyio
+
 try:
     import scipy.sparse
     have_sparse = 1
 except ImportError:
     have_sparse = 0
 
-if sys.version_info[0] < 2 or sys.version_info[1] < 3:
-    False = 0
-    True = 1
-
 LittleEndian = (sys.byteorder == 'little')
 
 _unit_imag = {'f': array(1j,'F'), 'd': 1j}
@@ -29,7 +28,7 @@
         mtype = 'B'
     elif mtype in ['S1', 'char', 'char*1']:
         mtype = 'B'
-    elif mtype in ['h','schar', 'signed char']:
+    elif mtype in ['b', 'schar', 'signed char']:
         mtype = 'b'
     elif mtype in ['h','short','int16','integer*2']:
         mtype = 'h'
@@ -37,7 +36,7 @@
         mtype = 'H'
     elif mtype in ['i','int']:
         mtype = 'i'
-    elif mtype in ['i','uint','uint32','unsigned int']:
+    elif mtype in ['I','uint','uint32','unsigned int']:
         mtype = 'I'
     elif mtype in ['l','long','int32','integer*4']:
         mtype = 'l'
@@ -55,57 +54,9 @@
     newarr = empty((1,),mtype)
     return newarr.itemsize, newarr.dtype.char
 
-if sys.version[:3] < "2.2":
-    class file:
-        def __init__(self, name, mode='r', bufsize=-1):
-            self.fid = open(name, mode, bufsize)
+class fopen(object):
+    """Class for reading and writing binary files into numpy arrays.
 
-        def close(self):
-            self.fid.close()
-
-        def flush(self):
-            self.fid.flush()
-
-        def fileno(self):
-            return self.fid.fileno()
-
-        def isatty(self):
-            return self.fid.isatty()
-
-        def read(size=-1):
-            return self.fid.read(size)
-
-        def readline(size=-1):
-            return self.fid.readlines()
-
-        def readlines(sizehint=None):
-            if sizehint is None:
-                return self.fid.readlines()
-            else:
-                return self.fid.readlines(sizehint)
-
-        def seek(offset, whence=0):
-            self.fid.seek(offset, whence)
-
-        def tell():
-            return self.fid.tell()
-
-        def truncate(size=None):
-            if size is None:
-                self.fid.truncate()
-            else:
-                self.fid.truncate(size)
-
-        def write(str):
-            self.fid.write(str)
-
-        def writelines(sequence):
-            self.fid.write(sequence)
-
-
-class fopen(file):
-    """Class for reading and writing binary files into Numeric arrays.
-
     Inputs:
 
       file_name -- The complete path name to the file to open.
@@ -127,8 +78,8 @@
 
 #    Methods:
 #
-#      read -- read data from file and return Numeric array
-#      write -- write to file from Numeric array
+#      read -- read data from file and return numpy array
+#      write -- write to file from numpy array
 #      fort_read -- read Fortran-formatted binary data from the file.
 #      fort_write -- write Fortran-formatted binary data to the file.
 #      rewind -- rewind to beginning of file
@@ -136,67 +87,63 @@
 #      seek -- seek to some position in the file
 #      tell -- return current position in file
 #      close -- close the file
-#
-#
-#
 
     def __init__(self,file_name,permission='rb',format='n'):
-        if 'B' not in permission: permission += 'B'
-        if type(file_name) in (types.StringType, types.UnicodeType):
-            file.__init__(self, file_name, permission)
-        elif 'fileno' in file_name.__methods__:  # first argument is an open file
-            self = file_name
-
-        if format in ['native','n','default']:
-            self.__dict__['bs'] = 0
-            self.__dict__['format'] = 'native'
-        elif format in ['ieee-le','l','little-endian','le']:
-            self.__dict__['bs'] = not LittleEndian
-            self.__dict__['format'] = 'ieee-le'
-        elif format in ['ieee-be','B','big-endian','be']:
-            self.__dict__['bs'] = LittleEndian
-            self.__dict__['format'] = 'ieee-be'
+        if 'b' not in permission: permission += 'b'
+        if isinstance(file_name, basestring):
+            self.file = file(file_name, permission)
+        elif isinstance(file_name, file) and not file_name.closed:
+            # first argument is an open file
+            self.file = file_name
         else:
-            raise ValueError, "Unrecognized format: " + format
-
-#    def __setattr__(self, attribute):
-#        raise SyntaxError, "There are no user-settable attributes."
-
+            raise TypeError, 'Need filename or open file as input'
+        self.setformat(format)
+        self.zbuffer = None
+        
     def __del__(self):
         try:
-            self.close()
+            self.file.close()
         except:
             pass
 
+    def close(self):
+        self.file.close()
+
+    def seek(self, *args):
+        self.file.seek(*args)
+
+    def tell(self):
+        self.file.tell()
+        
     def raw_read(self, size=-1):
         """Read raw bytes from file as string."""
-        return file.read(self, size)
+        return self.file.read(size)
 
     def raw_write(self, str):
         """Write string to file as raw bytes."""
-        return file.read(self, str)
+        return self.file.write(str)
 
     def setformat(self, format):
         """Set the byte-order of the file."""
         if format in ['native','n','default']:
-            self.__dict__['bs'] = False
-            self.__dict__['format'] = 'native'
+            self.bs = False
+            self.format = 'native'
         elif format in ['ieee-le','l','little-endian','le']:
-            self.__dict__['bs'] = not LittleEndian
-            self.__dict__['format'] = 'ieee-le'
+            self.bs = not LittleEndian
+            self.format = 'ieee-le'
         elif format in ['ieee-be','B','big-endian','be']:
-            self.__dict__['bs'] = LittleEndian
-            self.__dict__['format'] = 'ieee-be'
+            self.bs = LittleEndian
+            self.format = 'ieee-be'
         else:
             raise ValueError, "Unrecognized format: " + format
         return
 
     def write(self,data,mtype=None,bs=None):
-        """Write to open file object the flattened Numeric array data.
+        """Write to open file object the flattened numpy array data.
 
         Inputs:
 
-          data -- the Numeric array to write.
+          data -- the numpy array to write.
           mtype -- a string indicating the binary type to write.
                    The default is the type of data. If necessary a cast is made.
                    unsigned byte  : 'B', 'uchar', 'byte' 'unsigned char', 'int8',
@@ -227,13 +174,13 @@
             mtype = data.dtype.char
         howmany,mtype = getsize_type(mtype)
         count = product(data.shape)
-        numpyio.fwrite(self,count,data,mtype,bs)
+        numpyio.fwrite(self.file,count,data,mtype,bs)
         return
 
     fwrite = write
 
     def read(self,count,stype,rtype=None,bs=None,c_is_b=0):
-        """Read data from file and return it in a Numeric array.
+        """Read data from file and return it in a numpy array.
 
         Inputs:
 
@@ -249,7 +196,7 @@
 
         Outputs: (output,)
 
-          output -- a Numeric array of type rtype.
+          output -- a numpy array of type rtype.
         """
         if bs is None:
             bs = self.bs
@@ -291,9 +238,7 @@
             howmany,rtype = getsize_type(rtype)
         if count == 0:
             return zeros(0,rtype)
-        retval = numpyio.fread(self, count, stype, rtype, bs)
-        if len(retval) == 1:
-            retval = retval[0]
+        retval = numpyio.fread(self.file, count, stype, rtype, bs)
         if shape is not None:
             retval = resize(retval, shape)
         return retval
@@ -301,7 +246,7 @@
     fread = read
 
     def rewind(self,howmany=None):
-        """Rewind a file to it's beginning or by a specified amount.
+        """Rewind a file to its beginning or by a specified amount.
         """
         if howmany is None:
             self.seek(0)
@@ -318,7 +263,7 @@
             self.seek(0,2)
             sz = self.tell()
             self.seek(curpos)
-            self.__dict__['thesize'] = sz
+            self.thesize = sz
         return sz
 
     def fort_write(self,fmt,*args):
@@ -341,7 +286,7 @@
             nfmt = ">i"
         else:
             nfmt = "i"
-        if type(fmt) in (types.StringType, types.UnicodeType):
+        if isinstance(fmt, basestring):
             if self.format == 'ieee-le':
                 fmt = "<"+fmt
             elif self.format == 'ieee-be':
@@ -359,7 +304,7 @@
             count = product(fmt.shape)
             strlen = struct.pack(nfmt,count*sz)
             self.write(strlen)
-            numpyio.fwrite(self.fid,count,fmt,mtype,self.bs)
+            numpyio.fwrite(self.file,count,fmt,mtype,self.bs)
             self.write(strlen)
         else:
             raise TypeError, "Unknown type in first argument"
@@ -412,14 +357,42 @@
                 raise ValueError, "Negative number of bytes to read:\n    file is probably not opened with correct endian-ness."
             if ncount == 0:
                 raise ValueError, "End of file?  Zero-bytes to read."
-            retval = numpyio.fread(self, ncount, dtype, dtype, self.bs)
+            retval = numpyio.fread(self.file, ncount, dtype, dtype, self.bs)
             if len(retval) == 1:
                 retval = retval[0]
             if (self.raw_read(nn) == ''):
                 raise ValueError, "Unexpected end of file..."
             return retval
+        
 
+class CompressedFopen(fopen):
+    """ File container for temporary buffer to decompress data """
+    def __init__(self, *args, **kwargs):
+        fd, fname = mkstemp()
+        super(CompressedFopen, self).__init__(
+            os.fdopen(fd, 'w+b'), *args, **kwargs)
+        self.file_name = fname
+        
+    def fill(self, bytes):
+        """ Uncompress buffer in @bytes and write to file """
+        self.rewind()
+        self.raw_write(zlib.decompress(bytes))
+        self.rewind()
 
+    def __del__(self):
+        try:
+            self.file.truncate(0)
+        except:
+            pass
+        try:
+            self.close()
+        except:
+            pass
+        try:
+            os.remove(self.file_name)
+        except:
+            pass
+        
 #### MATLAB Version 5 Support ###########
 
 # Portions of code borrowed and (heavily) adapted
@@ -465,6 +438,10 @@
 miINT64 =12
 miUINT64 = 13
 miMATRIX = 14
+miCOMPRESSED = 15
+miUTF8 = 16
+miUTF16 = 17
+miUTF32 = 18
 
 miNumbers = (
     miINT8,
@@ -491,8 +468,23 @@
     miINT64 : ('miINT64',8,'q'),
     miUINT64 : ('miUINT64',8,'Q'),
     miMATRIX : ('miMATRIX',0,None),
+    miUTF8 : ('miUTF8',1,'b'),
+    miUTF16 : ('miUTF16',2,'h'),
+    miUTF32 : ('miUTF32',4,'l'),
     }
 
+''' Before release v7.1 (release 14) matlab used the system default
+character encoding scheme padded out to 16-bits. Release 14 and later
+use Unicode. When saving character data, matlab R14 checks if it can
+be encoded in 7-bit ascii, and saves in that format if so.'''
+miCodecs = {
+    miUINT8: 'ascii',
+    miUINT16: sys.getdefaultencoding(),
+    miUTF8: 'utf8',
+    miUTF16: 'utf16',
+    miUTF32: 'utf32',
+    } 
+
 mxCELL_CLASS = 1
 mxSTRUCT_CLASS = 2
 mxOBJECT_CLASS = 3
@@ -518,8 +510,8 @@
     mxINT32_CLASS,
     mxUINT32_CLASS,
     )
-
-def _parse_header(fid, dict):
+    
+def _parse_header(fid, hdict):
     correct_endian = (ord('M')<<8) + ord('I')
                  # if this number is read no BS
     fid.seek(126)  # skip to endian detector
@@ -531,12 +523,19 @@
         else: openstr = 'l'
     fid.setformat(openstr)  # change byte-order if necessary
     fid.rewind()
-    dict['__header__'] = fid.raw_read(124).strip(' \t\n\000')
+    hdict['__header__'] = fid.raw_read(124).strip(' \t\n\000')
     vers = fid.read(1,'int16')
-    dict['__version__'] = '%d.%d' % (vers >> 8, vers & 255)
+    hdict['__version__'] = '%d.%d' % (vers >> 8, vers & 0xFF)
     fid.seek(2,1)  # move to start of data
     return
 
+def _skip_padding(fid, numbytes, rowsize):
+    """ Skip to next row or @rowsize after previous read of @numbytes """
+    mod = numbytes % rowsize
+    if mod:
+        skip = rowsize-mod
+        fid.seek(skip,1)
+
 def _parse_array_flags(fid):
     # first 8 bytes are always miUINT32 and 8 --- just a check
     dtype, nbytes = fid.read(2,'I')
@@ -545,8 +544,8 @@
 
     # read array flags.
     rawflags = fid.read(2,'I')
-    class_ = rawflags[0] & 255
-    flags = (rawflags[0] & 65535) >> 8
+    class_ = rawflags[0] & 0xFF
+    flags = (rawflags[0] & 0xFFFF) >> 8
     # Global and logical fields are currently ignored
     if (flags & 8): cmplx = 1
     else: cmplx = 0
@@ -558,16 +557,34 @@
 
 def _parse_mimatrix(fid,bytes):
     dclass, cmplx, nzmax =_parse_array_flags(fid)
-    dims = _get_element(fid)[0]
-    name = asarray(_get_element(fid)[0]).tostring()
+    dims = _get_element(fid)
+    name = _get_element(fid).tostring()
     tupdims = tuple(dims[::-1])
     if dclass in mxArrays:
-        result, unused =_get_element(fid)
+        result, unused, dtype =_get_element(fid, return_name_dtype=True)
         if dclass == mxCHAR_CLASS:
-            result = ''.join(asarray(result).astype('S1'))
+            en = miCodecs[dtype]
+            try:
+                " ".encode(en)
+            except LookupError:
+                raise ValueError, 'Character encoding %s not supported' % en
+            if dtype == miUINT16:
+                char_len = len("  ".encode(en)) - len(" ".encode(en))
+                if char_len == 1: # Need to downsample from 16 bit
+                    result = result.astype(uint8)
+            result = squeeze(transpose(reshape(result,tupdims)))
+            dims = result.shape
+            if len(dims) >= 2: # return array of strings
+                n_dims = dims[:-1]
+                string_arr = reshape(result, (product(n_dims), dims[-1]))
+                result = empty(n_dims, dtype=object)
+                for i in range(0, n_dims[-1]):
+                    result[...,i] = string_arr[i].tostring().decode(en)
+            else: # return string
+                result = result.tostring().decode(en)
         else:
             if cmplx:
-                imag, unused =_get_element(fid)
+                imag  =_get_element(fid)
                 try:
                     result = result + _unit_imag[imag.dtype.char] * imag
                 except KeyError:
@@ -576,139 +593,167 @@
 
     elif dclass == mxCELL_CLASS:
         length = product(dims)
-        result = zeros(length, PyObject)
+        result = empty(length, dtype=object)
         for i in range(length):
-            sa, unused = _get_element(fid)
-            result[i]= sa
+            result[i] = _get_element(fid)
         result = squeeze(transpose(reshape(result,tupdims)))
-        if rank(result)==0: result = result.item()
+        if not result.shape:
+            result = result.item()
 
     elif dclass == mxSTRUCT_CLASS:
         length = product(dims)
-        result = zeros(length, PyObject)
-        namelength = _get_element(fid)[0]
+        result = zeros(length, object)
+        namelength = _get_element(fid)
         # get field names
-        names = _get_element(fid)[0]
+        names = _get_element(fid)
         splitnames = [names[i:i+namelength] for i in \
                       xrange(0,len(names),namelength)]
-        fieldnames = [''.join(asarray(x).astype('S1')).strip('\x00')
+        fieldnames = [x.tostring().strip('\x00')
                               for x in splitnames]
         for i in range(length):
             result[i] = mat_struct()
             for element in fieldnames:
-                val,unused = _get_element(fid)
-                result[i].__dict__[element] = val
+                result[i].__dict__[element]  = _get_element(fid)
         result = squeeze(transpose(reshape(result,tupdims)))
-        if rank(result)==0: result = result.item()
+        if not result.shape:
+            result = result.item()
 
         # object is like a structure with but with a class name
     elif dclass == mxOBJECT_CLASS:
-        class_name = ''.join(asarray(_get_element(fid)[0]).astype('S1'))
+        class_name = _get_element(fid).tostring()
         length = product(dims)
-        result = zeros(length, PyObject)
-        namelength = _get_element(fid)[0]
+        result = zeros(length, object)
+        namelength = _get_element(fid)
         # get field names
-        names = _get_element(fid)[0]
+        names = _get_element(fid)
         splitnames = [names[i:i+namelength] for i in \
                       xrange(0,len(names),namelength)]
-        fieldnames = [''.join(asarray(x).astype('S1')).strip('\x00')
+        fieldnames = [x.tostring().strip('\x00')
                               for x in splitnames]
         for i in range(length):
             result[i] = mat_obj()
             result[i]._classname = class_name
             for element in fieldnames:
-                val,unused = _get_element(fid)
-                result[i].__dict__[element] = val
+                result[i].__dict__[element] = _get_element(fid)
         result = squeeze(transpose(reshape(result,tupdims)))
-        if rank(result)==0: result = result.item()
+        if not result.shape:
+            result = result.item()
 
     elif dclass == mxSPARSE_CLASS:
-        rowind, unused = _get_element(fid)
-        colind, unused = _get_element(fid)
-        res, unused = _get_element(fid)
+        rowind  = _get_element(fid)
+        colind = _get_element(fid)
+        res = _get_element(fid)
         if cmplx:
-            imag, unused = _get_element(fid)
+            imag = _get_element(fid)
             try:
                 res = res + _unit_imag[imag.dtype.char] * imag
             except (KeyError,AttributeError):
                 res = res + 1j*imag
+        ''' From the matlab API documentation, last found here:
+        http://www.mathworks.com/access/helpdesk/help/techdoc/matlab_external/
+        @rowind are simply the row indices for all the (@res) non-zero
+        entries in the sparse array.  @rowind has nzmax entries, so
+        may well have more entries than len(@res), the actual number
+        of non-zero entries, but @rowind[len(res):] can be discarded
+        and should be 0. @colind has length (number of columns + 1),
+        and is such that, if D = diff(@colind), D[j] gives the number
+        of non-zero entries in column j. Because @rowind values are
+        stored in column order, this gives the column corresponding to
+        each @rowind
+        '''
+        cols = empty((len(res)), dtype=rowind.dtype)
+        col_counts = diff(colind)
+        start_row = 0
+        for i in where(col_counts)[0]:
+            end_row = start_row + col_counts[i]
+            cols[start_row:end_row] = i
+            start_row = end_row
+        ij = vstack((rowind[:len(res)], cols))
         if have_sparse:
-            spmat = scipy.sparse.csc_matrix(res, (rowind[:len(res)], colind),
-                                            M=dims[0],N=dims[1])
-            result = spmat
+            result = scipy.sparse.csc_matrix((res,ij), [dims[0],dims[1]])
         else:
-            result = (dims, rowind, colind, res)
+            result = (dims, ij, res)
 
     return result, name
 
 # Return a Python object for the element
-def _get_element(fid):
+def _get_element(fid, return_name_dtype=False):
+    """ Return a python object from next element in @fid
 
+    @fid    - fopen object for matfile
+    @return_name_dtype - if True, return tuple of (element, name, dtype)
+                         if False, return element only
+    """
+    name = None
     test = fid.raw_read(1)
     if len(test) == 0:  # nothing left
         raise EOFError
     else:
         fid.rewind(1)
     # get the data tag
-    raw_tag = fid.read(1,'I')
-
-    # check for compressed
+    raw_tag = int(fid.read(1,'I'))
+    
+    # check for small data element format
     numbytes = raw_tag >> 16
-    if numbytes > 0:  # compressed format
+    if numbytes > 0:  # small data element format
         if numbytes > 4:
             raise IOError, "Problem with MAT file: " \
-                  "too many bytes in compressed format."
-        dtype = raw_tag & 65535
+                  "too many bytes in small data element format."
+        dtype = int(raw_tag & 0xFFFF)
         el = fid.read(numbytes,miDataTypes[dtype][2],c_is_b=1)
         fid.seek(4-numbytes,1)  # skip padding
-        return el, None
+    else:
+        # otherwise parse tag
+        dtype = raw_tag
+        numbytes = fid.read(1,'I')
+        
+        if dtype == miCOMPRESSED: # compressed data type
+            if not fid.zbuffer:
+                fid.zbuffer = CompressedFopen(format=fid.format)
+            fid.zbuffer.fill(fid.raw_read(numbytes))
+            _skip_padding(fid, numbytes, 8)
+            return _get_element(fid.zbuffer, return_name_dtype)
+        if dtype != miMATRIX:  # basic data type
+            try:
+                el = fid.read(numbytes,miDataTypes[dtype][2],c_is_b=1)
+            except KeyError:
+                raise ValueError, "Unknown data type"
+            _skip_padding(fid, numbytes, 8)
+        else:
+            # handle miMatrix type
+            el, name = _parse_mimatrix(fid,numbytes)
 
-    # otherwise parse tag
-    dtype = raw_tag
-    numbytes = fid.read(1,'I')
-    if dtype != miMATRIX:  # basic data type
-        try:
-            outarr = fid.read(numbytes,miDataTypes[dtype][2],c_is_b=1)
-        except KeyError:
-            raise ValueError, "Unknown data type"
-        mod8 = numbytes%8
-        if mod8:       # skip past padding
-            skip = 8-mod8
-            fid.seek(skip,1)
-        return outarr, None
+    if return_name_dtype:
+        return el, name, dtype
+    return el
 
-    # handle miMatrix type
-    el, name = _parse_mimatrix(fid,numbytes)
-    return el, name
-
 def _loadv5(fid,basename):
-    # return a dictionary from a Matlab version 5 file
+    # return a dictionary from a Matlab version 5-7.1 file
     # always contains the variable __header__
-    dict = {}
-    _parse_header(fid,dict)
+    mdict = {}
+    _parse_header(fid,mdict)
     var = 0
     while 1:  # file pointer to start of next data
         try:
             var = var + 1
-            el, varname = _get_element(fid)
+            el, varname, unused = _get_element(fid, return_name_dtype=True)
             if varname is None:
                 varname = '%s_%04d' % (basename,var)
-            dict[varname] = el
+            mdict[varname] = el
         except EOFError:
             break
-    return dict
+    return mdict
 
 ### END MATLAB v5 support #############
 
-def loadmat(name, dict=None, appendmat=1, basename='raw'):
+def loadmat(name, mdict=None, appendmat=1, basename='raw'):
     """Load the MATLAB(tm) mat file.
 
     If name is a full path name load it in.  Otherwise search for the file
     on the sys.path list and load the first one found (the current directory
     is searched first).
 
-    Both v4 (Level 1.0) and v6 matfiles are supported.  Version 7.0 files
-    are not yet supported.
+    v4 (Level 1.0), v6 and v7.1 matfiles are supported.  
 
     Inputs:
 
@@ -754,13 +799,13 @@
     if not (0 in test_vals):       # MATLAB version 5 format
         fid.rewind()
         thisdict = _loadv5(fid,basename)
-        if dict is not None:
-            dict.update(thisdict)
+        if mdict is not None:
+            mdict.update(thisdict)
             return
         else:
             return thisdict
-
-
+        
+    # The remainder of this function is the v4 codepath
     testtype = struct.unpack('i',test_vals.tostring())
     # Check to see if the number is positive and less than 5000.
     if testtype[0] < 0 or testtype[0] > 4999:
@@ -816,12 +861,13 @@
             data = atleast_1d(fid.fread(numels,storage))
             if header[3]:  # imaginary data
                 data2 = fid.fread(numels,storage)
-                new = zeros(data.shape,data.dtype.char.capitalize())
-                new.real = data
-                new.imag = data2
-                data = new
-                del(new)
-                del(data2)
+                if data.dtype.char == 'f' and data2.dtype.char == 'f':
+                    new = empty(data.shape,'F')
+                    new.real = data
+                    new.imag = data2
+                    data = new
+                    del(new)
+                    del(data2)
             if len(data) > 1:
                 data=data.reshape((header[2], header[1])                )
                 thisdict[varname] = transpose(squeeze(data))
@@ -836,14 +882,14 @@
                 thisdict[varname] = data
 
     fid.close()
-    if dict is not None:
+    if mdict is not None:
         print "Names defined = ", defnames
-        dict.update(thisdict)
+        mdict.update(thisdict)
     else:
         return thisdict
 
 
-def savemat(filename, dict):
+def savemat(filename, mdict):
     """Save a dictionary of names and arrays into the MATLAB-style .mat file.
 
     This saves the arrayobjects in the given dictionary to a matlab Version 4
@@ -855,8 +901,8 @@
     fid = fopen(filename,'wb')
     M = not LittleEndian
     O = 0
-    for variable in dict.keys():
-        var = dict[variable]
+    for variable in mdict.keys():
+        var = mdict[variable]
         if not isinstance(var, ndarray):
             continue
         if var.dtype.char == 'S1':

Added: trunk/Lib/io/tests/data/test3dmatrix_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/test3dmatrix_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testcell_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testcell_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testcellnest_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testcellnest_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testcomplex_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testcomplex_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testdouble_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testdouble_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testmatrix_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testmatrix_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testminus_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testminus_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testobject_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testobject_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testonechar_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testonechar_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testsparse_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testsparse_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/testsparsecomplex_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/testsparsecomplex_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/teststring_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/teststring_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/teststringarray_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/teststringarray_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/teststruct_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/teststruct_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/teststructarr_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/teststructarr_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/data/teststructnest_6.5.1_GLNX86.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/io/tests/data/teststructnest_6.5.1_GLNX86.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/io/tests/test_mio.py
===================================================================
--- trunk/Lib/io/tests/test_mio.py	2006-08-09 22:43:16 UTC (rev 2156)
+++ trunk/Lib/io/tests/test_mio.py	2006-08-12 21:37:25 UTC (rev 2157)
@@ -0,0 +1,235 @@
+#!/usr/bin/env python
+""" Tests for scipy.io.mio.loadmat
+
+Each case in test_mio_array.case_table loads every file in the data
+directory matching test<name>_*.mat and compares the variables read from
+it against the expected values given in the table.
+"""
+
+import os
+from glob import glob
+from numpy.testing import set_package_path, restore_path, ScipyTestCase, ScipyTest
+from numpy.testing import assert_equal, assert_array_almost_equal
+from numpy import arange, array, eye, pi, cos, exp, sin, sqrt, ndarray,  \
+     zeros, reshape, transpose, empty
+import scipy.sparse as SP
+
+set_package_path()
+from scipy.io.mio import loadmat, mat_obj, mat_struct
+restore_path()
+
+try:  # Python 2.3 support: the builtin set type only exists from 2.4 on
+    set
+except NameError:
+    # only fall back to sets.Set when there is no builtin, so 2.4+ keeps it
+    from sets import Set as set
+
+class test_mio_array(ScipyTestCase):
+    """ Check loadmat against the MATLAB files in the data directory
+
+    One check_<name> method is generated per entry of case_table at the
+    end of the class body.
+    """
+    def __init__(self, *args, **kwargs):
+        super(test_mio_array, self).__init__(*args, **kwargs)
+        # the .mat fixtures live in the data directory next to this file
+        self.test_data_path = os.path.join(os.path.dirname(__file__), './data')
+
+    def _check_level(self, label, expected, actual):
+        """ Check one level of a potentially nested dictionary / list
+
+        label    - position description used in assertion messages
+        expected - expected value at this level
+        actual   - value loadmat returned at this level
+
+        Object arrays and mat_struct / mat_obj instances are checked
+        recursively; anything else is compared as a single value.
+        """
+        # object array is returned from cell array in mat file
+        if isinstance(expected, ndarray) and expected.dtype.hasobject == 1:
+            assert type(expected) is type(actual), "Different types at %s" % label
+            assert len(expected) == len(actual), "Different list lengths at %s" % label
+            for i, ev in enumerate(expected):
+                level_label = "%s, [%d], " % (label, i)
+                self._check_level(level_label, ev, actual[i])
+            return
+        # object, as container for matlab structs and objects
+        elif isinstance(expected, mat_struct) or isinstance(expected, mat_obj):
+            assert isinstance(actual, type(expected)), \
+                   "Different types %s and %s at %s" % \
+                   (type(expected), type(actual), label)
+            ex_fields = dir(expected)
+            ac_fields = dir(actual)
+            for k in ex_fields:
+                # skip dunder names reported by dir()
+                if k.startswith('__') and k.endswith('__'):
+                    continue
+                assert k in ac_fields, "Missing field at %s" % label
+                ev = expected.__dict__[k]
+                v = actual.__dict__[k]
+                level_label = "%s, field %s, " % (label, k)
+                self._check_level(level_label, ev, v)
+            return
+        # hoping this is a single value, which might be an array
+        if SP.issparse(expected):
+            assert SP.issparse(actual), "Expected sparse at %s" % label
+            assert_array_almost_equal(actual.todense(),
+                                      expected.todense(),
+                                      err_msg = label)
+        elif isinstance(expected, ndarray):
+            assert isinstance(actual, ndarray), "Expected ndarray at %s" % label
+            assert_array_almost_equal(actual, expected, err_msg=label)
+        else:
+            # NOTE(review): unlike the struct check above, this requires
+            # expected to be an instance of type(actual); confirm the order.
+            assert isinstance(expected, type(actual)), \
+                   "Types %s and %s do not match at %s" % (type(expected), type(actual), label)
+            assert_equal(actual, expected, err_msg=label)
+
+    def _check_case(self, name, case):
+        """ Load every data file for test `name` and check its variables
+
+        name - case name; files matching test<name>_*.mat are loaded
+        case - dict mapping MATLAB variable name to expected value
+        """
+        filt = os.path.join(self.test_data_path, 'test%s_*.mat' % name)
+        files = glob(filt)
+        # fail loudly rather than pass vacuously if the fixtures are missing
+        assert files, "No files for test %s using filter %s" % (name, filt)
+        for f in files:
+            matdict = loadmat(f)
+            label = "Test '%s', file:%s" % (name, f)
+            for k, expected in case.items():
+                k_label = "%s, variable %s" % (label, k)
+                assert k in matdict, "Missing key at %s" % k_label
+                self._check_level(k_label, expected, matdict[k])
+
+    # Add the actual tests dynamically, with given parameters
+    def _make_check_case(name, expected):
+        """ Return a test method that runs _check_case(name, expected)
+
+        Called directly in the class body (hence no self); each call
+        gives the returned closure its own name / expected bindings.
+        """
+        def cc(self):
+            self._check_case(name, expected)
+        cc.__doc__ = "check loadmat case %s" % name
+        return cc
+
+    # Define cases to test: each entry gives the file name stem ('name')
+    # and a dict of MATLAB variable name -> expected loadmat value
+    theta = pi/4*arange(9,dtype=float)
+    case_table = [
+        {'name': 'double',
+         'expected': {'testdouble': theta}
+         }]
+    case_table.append(
+        {'name': 'string',
+         'expected': {'teststring': u'"Do nine men interpret?" "Nine men," I nod.'},
+         })
+    case_table.append(
+        {'name': 'complex',
+         'expected': {'testcomplex': cos(theta) + 1j*sin(theta)}
+         })
+    case_table.append(
+        {'name': 'cell',
+         'expected': {'testcell':
+                      array([u'This cell contains this string and 3 arrays of '+\
+                             'increasing length',
+                             array([1]), array([1,2]), array([1,2,3])],
+                            dtype=object)}
+         })
+    st = mat_struct()
+    st.stringfield = u'Rats live on no evil star.'
+    st.doublefield = array([sqrt(2),exp(1),pi])
+    st.complexfield = (1+1j)*array([sqrt(2),exp(1),pi])
+    case_table.append(
+        {'name': 'struct',
+         'expected': {'teststruct': st}
+         })
+    A = zeros((3,5))
+    A[0] = range(1,6)
+    A[:,0] = range(1,4)
+    case_table.append(
+        {'name': 'matrix',
+         'expected': {'testmatrix': A},
+         })
+    # transpose of the (4,3,2) reshape gives a (2,3,4) array
+    case_table.append(
+        {'name': '3dmatrix',
+         'expected': {'test3dmatrix': transpose(reshape(range(1,25), (4,3,2)))}
+         })
+    case_table.append(
+        {'name': 'sparse',
+         'expected': {'testsparse': SP.csc_matrix(A)},
+         })
+    B = A.astype(complex)
+    B[0,0] += 1j
+    case_table.append(
+        {'name': 'sparsecomplex',
+         'expected': {'testsparsecomplex': SP.csc_matrix(B)},
+         })
+    case_table.append(
+        {'name': 'minus',
+         'expected': {'testminus': array([-1])},
+         })
+    case_table.append(
+        {'name': 'onechar',
+         'expected': {'testonechar': u'r'},
+         })
+    case_table.append(
+        {'name': 'stringarray',
+         'expected': {'teststringarray': array([u'one  ', u'two  ', u'three'], dtype=object)},
+         })
+    case_table.append(
+        {'name': 'cellnest',
+         'expected': {'testcellnest': array([array([1]),
+                                             array([array([2]), array([3]),
+                                                   array([array([4]), array([5])],
+                                                                dtype=object)],
+                                                          dtype=object)],
+                                             dtype=object)},
+         })
+    # st is rebound here; the 'struct' entry above keeps its own object
+    st = mat_struct()
+    st.one = array([1])
+    st.two = mat_struct()
+    st.two.three = u'number 3'
+    case_table.append(
+        {'name': 'structnest',
+         'expected': {'teststructnest': st}
+         })
+    a = empty((2), dtype=object)
+    a[0], a[1] = mat_struct(), mat_struct()
+    a[0].one = array([1])
+    a[0].two = array([2])
+    a[1].one = u'number 1'
+    a[1].two = u'number 2'
+    case_table.append(
+        {'name': 'structarr',
+         'expected': {'teststructarr': a}
+         })
+
+    a = mat_obj()
+    a._classname = 'inline'
+    a.expr = u'x'
+    a.inputExpr = u' x = INLINE_INPUTS_{1};'
+    a.args = u'x'
+    a.isEmpty = array([0])
+    a.numArgs = array([1])
+    a.version = array([1])
+    case_table.append(
+        {'name': 'object',
+         'expected': {'testobject': a}
+         })
+
+    # add tests
+    for case in case_table:
+        name = case['name']
+        expected = case['expected']
+        # exec binds check_<name> in the class namespace as a test method
+        exec 'check_%s = _make_check_case(name, expected)' % name
+
+if __name__ == "__main__":
+    ScipyTest().run()
+



More information about the Scipy-svn mailing list