[Scipy-svn] r5029 - in trunk/scipy/io/matlab: . tests

scipy-svn@scip... scipy-svn@scip...
Sat Nov 8 23:32:51 CST 2008


Author: matthew.brett@gmail.com
Date: 2008-11-08 23:32:46 -0600 (Sat, 08 Nov 2008)
New Revision: 5029

Modified:
   trunk/scipy/io/matlab/mio5.py
   trunk/scipy/io/matlab/tests/gen_mat5files.m
   trunk/scipy/io/matlab/tests/test_mio.py
Log:
Added all load checks to the roundtrip checks; fixed the transpose when saving char data in MATLAB 5 format. All 51 tests now pass for me.

Modified: trunk/scipy/io/matlab/mio5.py
===================================================================
--- trunk/scipy/io/matlab/mio5.py	2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/mio5.py	2008-11-09 05:32:46 UTC (rev 5029)
@@ -728,14 +728,22 @@
                      is_global=False,
                      is_complex=False,
                      is_logical=False,
-                     nzmax=0):
+                     nzmax=0,
+                     shape=None):
         ''' Write header for given data options
         mclass      - mat5 matrix class
         is_global   - True if matrix is global
         is_complex  - True if matrix is complex
         is_logical  - True if matrix is logical
         nzmax        - max non zero elements for sparse arrays
+        shape : {None, tuple} optional
+            directly specify shape if this is not the same as for
+            self.arr
         '''
+        if shape is None:
+            shape = self.arr.shape
+            if len(shape) < 2:
+                shape = shape + (0,) * (len(shape)-2)
         self._mat_tag_pos = self.file_stream.tell()
         self.write_dtype(self.mat_tag)
         # write array flags (complex, global, logical, class, nzmax)
@@ -746,13 +754,7 @@
         af['flags_class'] = mclass | flags << 8
         af['nzmax'] = nzmax
         self.write_dtype(af)
-        # write array shape
-        if self.arr.ndim < 2:
-            new_arr = np.atleast_2d(self.arr)
-            if type(new_arr) != type(self.arr):
-                raise ValueError("Array should be 2-dimensional.")
-            self.arr = new_arr
-        self.write_element(np.array(self.arr.shape, dtype='i4'))
+        self.write_element(np.array(shape, dtype='i4'))
         # write name
         self.write_element(np.array([ord(c) for c in self.name], 'i1'))
 
@@ -786,22 +788,33 @@
             self.write_element(self.arr)
         self.update_matrix_tag()
 
+
 class Mat5CharWriter(Mat5MatrixWriter):
     codec='ascii'
     def write(self):
         self.arr_to_chars()
-        self.write_header(mclass=mxCHAR_CLASS)
+        # We have to write the shape directly, because we are going
+        # to recode the characters, and the resulting stream of chars
+        # may have a different length
+        shape = self.arr.shape
+        self.write_header(mclass=mxCHAR_CLASS,shape=shape)
+        # We need to do our own transpose (not using the normal
+        # write routines that do this for us)
+        arr = self.arr.T.copy()
         if self.arr.dtype.kind == 'U':
             # Recode unicode using self.codec
-            n_chars = np.product(self.arr.shape)
+            n_chars = np.product(shape)
             st_arr = np.ndarray(shape=(),
                                 dtype=self.arr_dtype_number(n_chars),
-                                buffer=self.arr)
+                                buffer=arr)
             st = st_arr.item().encode(self.codec)
-            self.arr = np.ndarray(shape=(len(st)), dtype='u1', buffer=st)
-        self.write_element(self.arr,mdtype=miUTF8)
+            arr = np.ndarray(shape=(len(st),),
+                             dtype='u1',
+                             buffer=st)
+        self.write_element(arr, mdtype=miUTF8)
         self.update_matrix_tag()
 
+
 class Mat5UniCharWriter(Mat5CharWriter):
     codec='UTF8'
 
@@ -976,17 +989,20 @@
                 continue
             is_global = name in self.global_vars
             self.writer_getter.rewind()
-            self.writer_getter.matrix_writer_factory(
+            mat_writer = self.writer_getter.matrix_writer_factory(
                 var,
                 name,
-                is_global,
-                ).write()
+                is_global)
+            mat_writer.write()
             stream = self.writer_getter.stream
+            bytes_written = stream.tell()
+            stream.seek(0)
+            out_str = stream.read(bytes_written)
             if self.do_compression:
-                str = zlib.compress(stream.getvalue(stream.tell()))
+                out_str = zlib.compress(out_str)
                 tag = np.empty((), mdtypes_template['tag_full'])
                 tag['mdtype'] = miCOMPRESSED
                 tag['byte_count'] = len(str)
-                self.file_stream.write(tag.tostring() + str)
+                self.file_stream.write(tag.tostring() + out_str)
             else:
-                self.file_stream.write(stream.getvalue(stream.tell()))
+                self.file_stream.write(out_str)

Modified: trunk/scipy/io/matlab/tests/gen_mat5files.m
===================================================================
--- trunk/scipy/io/matlab/tests/gen_mat5files.m	2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/tests/gen_mat5files.m	2008-11-09 05:32:46 UTC (rev 5029)
@@ -89,4 +89,8 @@
   fclose(fid);
   save_matfile('testunicode', native2unicode(from_japan, 'utf-8'));
 end
-  
\ No newline at end of file
+  
+% sparse float
+
+
+% sparse complex

Modified: trunk/scipy/io/matlab/tests/test_mio.py
===================================================================
--- trunk/scipy/io/matlab/tests/test_mio.py	2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/tests/test_mio.py	2008-11-09 05:32:46 UTC (rev 5029)
@@ -25,87 +25,16 @@
 test_data_path = join(dirname(__file__), 'data')
 
 def mlarr(*args, **kwargs):
-    ''' Return matlab-compatible 2D array'''
+    ''' Convenience function to return matlab-compatible 2D array
+    Note that matlab writes empty shape as (0,0) - replicated here
+    '''
     arr = np.array(*args, **kwargs)
     if arr.size:
         return np.atleast_2d(arr)
     # empty elements return as shape (0,0)
     return arr.reshape((0,0))
 
-def _check_level(label, expected, actual):
-    """ Check one level of a potentially nested array """
-    if SP.issparse(expected): # allow different types of sparse matrices
-        assert SP.issparse(actual)
-        assert_array_almost_equal(actual.todense(),
-                                  expected.todense(),
-                                  err_msg = label,
-                                  decimal = 5)
-        return
-    # Check types are as expected
-    typex = type(expected)
-    typac = type(actual)
-    assert typex is typac, \
-           "Expected type %s, got %s at %s" % (typex, typac, label)
-    # object, as container for matlab objects
-    if isinstance(expected, MatlabObject):
-        ex_fields = dir(expected)
-        ac_fields = dir(actual)
-        for k in ex_fields:
-            if k.startswith('__') and k.endswith('__'):
-                continue
-            assert k in ac_fields, \
-                   "Missing expected property %s for %s" % (k, label)
-            ev = expected.__dict__[k]
-            v = actual.__dict__[k]
-            level_label = "%s, property %s, " % (label, k)
-            _check_level(level_label, ev, v)
-        return
-    # A field in a record array may not be an ndarray
-    # A scalar from a record array will be type np.void
-    if not isinstance(expected, (np.void, np.ndarray)): 
-        assert_equal(expected, actual)
-        return
-    # This is an ndarray
-    assert_true(expected.shape == actual.shape,
-                msg='Expected shape %s, got %s at %s' % (expected.shape,
-                                                         actual.shape,
-                                                         label)
-                )
-    ex_dtype = expected.dtype
-    if ex_dtype.hasobject: # array of objects
-        for i, ev in enumerate(expected):
-            level_label = "%s, [%d], " % (label, i)
-            _check_level(level_label, ev, actual[i])
-        return
-    if ex_dtype.fields: # probably recarray
-        for fn in ex_dtype.fields:
-            level_label = "%s, field %s, " % (label, fn)
-            _check_level(level_label,
-                         expected[fn], actual[fn])
-        return
-    if ex_dtype.type in (np.unicode, # string
-                         np.unicode_):
-        assert_equal(actual, expected, err_msg=label)
-        return
-    # Something numeric
-    assert_array_almost_equal(actual, expected, err_msg=label, decimal=5)
 
-def _check_case(name, files, case):
-    for file_name in files:
-        matdict = loadmat(file_name, struct_as_record=True)
-        label = "test %s; file %s" % (name, file_name)
-        for k, expected in case.items():
-            k_label = "%s, variable %s" % (label, k)
-            assert k in matdict, "Missing key at %s" % k_label
-            _check_level(k_label, expected, matdict[k])
-
-# Round trip tests
-def _rt_check_case(name, expected, format):
-    mat_stream = StringIO()
-    savemat(mat_stream, expected, format=format)
-    mat_stream.seek(0)
-    _check_case(name, [mat_stream], expected)
-
 # Define cases to test
 theta = np.pi/4*np.arange(9,dtype=float).reshape(1,9)
 case_table4 = [
@@ -181,20 +110,6 @@
      'expected': {
     'test3dmatrix': np.transpose(np.reshape(range(1,25), (4,3,2)))}
      })
-case_table5_rt = [
-    {'name': '3dmatrix',
-     'expected': {
-    'test3dmatrix': np.transpose(np.reshape(range(1,25), (4,3,2)))}
-     },
-    {'name': 'sparsefloat',
-     'expected': {'testsparsefloat':
-                  SP.coo_matrix(array([[1,0,2],[0,-3.5,0]]))},
-     },
-    {'name': 'sparsecomplex',
-     'expected': {'testsparsefloat':
-                  SP.coo_matrix(array([[-1+2j,0,2],[0,-3j,0]]))},
-     },
-    ]
 st_sub_arr = array([np.sqrt(2),np.exp(1),np.pi]).reshape(1,3)
 dtype = [(n, object) for n in ['stringfield', 'doublefield', 'complexfield']]
 st1 = np.zeros((1,1), dtype)
@@ -254,7 +169,95 @@
     {'name': 'unicode',
     'expected': {'testunicode': array([u_str])}
     })
+# These should also have matlab load equivalents, but I can't get to matlab at the moment
+case_table5_rt = case_table5[:]
+case_table5_rt.append(
+    {'name': 'sparsefloat',
+     'expected': {'testsparsefloat':
+                  SP.coo_matrix(array([[1,0,2],[0,-3.5,0]]))},
+     })
+case_table5_rt.append(
+    {'name': 'sparsecomplex',
+     'expected': {'testsparsecomplex':
+                  SP.coo_matrix(array([[-1+2j,0,2],[0,-3j,0]]))},
+     })
 
+
+def _check_level(label, expected, actual):
+    """ Check one level of a potentially nested array """
+    if SP.issparse(expected): # allow different types of sparse matrices
+        assert SP.issparse(actual)
+        assert_array_almost_equal(actual.todense(),
+                                  expected.todense(),
+                                  err_msg = label,
+                                  decimal = 5)
+        return
+    # Check types are as expected
+    typex = type(expected)
+    typac = type(actual)
+    assert typex is typac, \
+           "Expected type %s, got %s at %s" % (typex, typac, label)
+    # object, as container for matlab objects
+    if isinstance(expected, MatlabObject):
+        ex_fields = dir(expected)
+        ac_fields = dir(actual)
+        for k in ex_fields:
+            if k.startswith('__') and k.endswith('__'):
+                continue
+            assert k in ac_fields, \
+                   "Missing expected property %s for %s" % (k, label)
+            ev = expected.__dict__[k]
+            v = actual.__dict__[k]
+            level_label = "%s, property %s, " % (label, k)
+            _check_level(level_label, ev, v)
+        return
+    # A field in a record array may not be an ndarray
+    # A scalar from a record array will be type np.void
+    if not isinstance(expected, (np.void, np.ndarray)): 
+        assert_equal(expected, actual)
+        return
+    # This is an ndarray
+    assert_true(expected.shape == actual.shape,
+                msg='Expected shape %s, got %s at %s' % (expected.shape,
+                                                         actual.shape,
+                                                         label)
+                )
+    ex_dtype = expected.dtype
+    if ex_dtype.hasobject: # array of objects
+        for i, ev in enumerate(expected):
+            level_label = "%s, [%d], " % (label, i)
+            _check_level(level_label, ev, actual[i])
+        return
+    if ex_dtype.fields: # probably recarray
+        for fn in ex_dtype.fields:
+            level_label = "%s, field %s, " % (label, fn)
+            _check_level(level_label,
+                         expected[fn], actual[fn])
+        return
+    if ex_dtype.type in (np.unicode, # string
+                         np.unicode_):
+        assert_equal(actual, expected, err_msg=label)
+        return
+    # Something numeric
+    assert_array_almost_equal(actual, expected, err_msg=label, decimal=5)
+
+def _load_check_case(name, files, case):
+    for file_name in files:
+        matdict = loadmat(file_name, struct_as_record=True)
+        label = "test %s; file %s" % (name, file_name)
+        for k, expected in case.items():
+            k_label = "%s, variable %s" % (label, k)
+            assert k in matdict, "Missing key at %s" % k_label
+            _check_level(k_label, expected, matdict[k])
+
+# Round trip tests
+def _rt_check_case(name, expected, format):
+    mat_stream = StringIO()
+    savemat(mat_stream, expected, format=format)
+    mat_stream.seek(0)
+    _load_check_case(name, [mat_stream], expected)
+
+
 # generator for load tests
 def test_load():
     for case in case_table4 + case_table5:
@@ -263,7 +266,7 @@
         filt = join(test_data_path, 'test%s_*.mat' % name)
         files = glob(filt)
         assert files, "No files for test %s using filter %s" % (name, filt)
-        yield _check_case, name, files, expected
+        yield _load_check_case, name, files, expected
 
 # generator for round trip tests
 def test_round_trip():



More information about the Scipy-svn mailing list