[Scipy-svn] r2248 - trunk/Lib/io

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Oct 9 08:47:08 CDT 2006


Author: matthew.brett at gmail.com
Date: 2006-10-09 08:47:02 -0500 (Mon, 09 Oct 2006)
New Revision: 2248

Modified:
   trunk/Lib/io/mio5.py
   trunk/Lib/io/miobase.py
Log:
More progress on implementing Mat5 write

Modified: trunk/Lib/io/mio5.py
===================================================================
--- trunk/Lib/io/mio5.py	2006-10-08 02:12:06 UTC (rev 2247)
+++ trunk/Lib/io/mio5.py	2006-10-09 13:47:02 UTC (rev 2248)
@@ -524,8 +524,23 @@
 
 class Mat5MatrixWriter(MatStreamWriter):
 
+    mat_tag = zeros((), mdtypes_template['tag_full'])
+    mat_tag['mdtype'] = miMATRIX
+
+    def __init__(self, file_stream, arr, name, is_global=False):
+        super(Mat5MatrixWriter, self).__init__(file_stream, arr, name)
+        self.is_global = is_global
+
+    def write_dtype(self, arr):
+        self.file_stream.write(arr.tostring)
+
+    def write_element(self, arr):
+        # check if small element works - do it
+        # write tag, data
+        pass
+
     def write_header(self, mclass,
-                     is_global,
+                     is_global=False,
                      is_complex=False,
                      is_logical=False,
                      nzmax=0):
@@ -534,21 +549,27 @@
         @is_global   - True if matrix is global
         @is_complex  - True is matrix is complex
         @is_logical  - True if matrix is logical
+        nzmax        - max non zero elements for sparse arrays
         '''
-        dims = self.arr.shape
-        header = empty((), mdtypes_template['header'])
-        M = not ByteOrder.little_endian
-        O = 0
-        header['mopt'] = (M * 1000 +
-                          O * 100 + 
-                          P * 10 +
-                          T)
-        header['mrows'] = dims[0]
-        header['ncols'] = dims[1]
-        header['imagf'] = imagf
-        header['namlen'] = len(self.name) + 1
-        self.write_bytes(header)
-        self.write_string(self.name + '\0')
+        self._mat_tag_pos = self.file_stream.tell()
+        self.write_dtype(self.mat_tag)
+        # write array flags (complex, global, logical, class, nzmax)
+        af = zeros((), mdtypes_template['array_flags'])
+        af['data_type'] = miUINT32
+        af['byte_count'] = 8
+        flags = is_complex << 3 | is_global << 2 | is_logical << 1
+        af['flags_class'] = mclass | flags << 8
+        af['nzmax'] = nzmax
+        self.write_dtype(af)
+        self.write_element(array(self.arr.shape, dtype='i4'))
+        self.write_element(self.name)
+
+    def update_matrix_tag(self):
+        curr_pos = self.file_stream.tell()
+        self.file_stream.seek(self._mat_tag_pos)
+        self.mat_tag['byte_count'] = curr_pos - self._mat_tag_pos - 8
+        self.write_dtype(self.mat_tag)
+        self.file_stream.seek(curr_pos)
         
     def write(self):
         assert False, 'Not implemented'
@@ -559,10 +580,6 @@
     def write(self):
         # identify matlab type for array
         # make at least 2d
-        # write miMATRIX tag
-        # write array flags (complex, global, logical, class, nzmax)
-        # dimensions
-        # array name
         # maybe downcast array to smaller matlab type
         # write real
         # write imaginary
@@ -611,75 +628,84 @@
                           T=mxSPARSE_CLASS,
                           dims=ijd.shape)
         self.write_bytes(ijd)
-        
-    
-def matrix_writer_factory(stream, arr, name, unicode_strings=False, is_global=False):
-    ''' Factory function to return matrix writer given variable to write
-    @stream      - file or file-like stream to write to
-    @arr         - array to write
-    @name        - name in matlab (TM) workspace
-    '''
-    if have_sparse:
-        if scipy.sparse.issparse(arr):
-            return Mat5SparseWriter(stream, arr, name, is_global)
-    arr = array(arr)
-    if arr.dtype.hasobject:
-        types, arr_type = classify_mobjects(arr)
-        if arr_type == 'c':
-            return Mat5CellWriter(stream, arr, name, is_global, types)
-        elif arr_type == 's':
-            return Mat5StructWriter(stream, arr, name, is_global)
-        elif arr_type == 'o':
-            return Mat5ObjectWriter(stream, arr, name, is_global)
-    if arr.dtype.kind in ('U', 'S'):
-        if unicode_strings:
-            return Mat5UniCharWriter(stream, arr, name, is_global)
+
+
+class Mat5WriterGetter(object):
+    ''' Wraps stream and options, provides methods for getting Writer objects '''
+    def __init__(self, stream, unicode_strings):
+        self.stream = stream
+        self.unicode_strings = unicode_strings
+
+    def rewind(self):
+        self.stream.seek(0)
+
+    def matrix_writer_factory(self, arr, name, is_global=False):
+        ''' Factory function to return matrix writer given variable to write
+        @stream      - file or file-like stream to write to
+        @arr         - array to write
+        @name        - name in matlab (TM) workspace
+        '''
+        if have_sparse:
+            if scipy.sparse.issparse(arr):
+                return Mat5SparseWriter(self.stream, arr, name, is_global)
+        arr = array(arr)
+        if arr.dtype.hasobject:
+            types, arr_type = classify_mobjects(arr)
+            if arr_type == 'c':
+                return Mat5CellWriter(self.stream, arr, name, is_global, types)
+            elif arr_type == 's':
+                return Mat5StructWriter(self.stream, arr, name, is_global)
+            elif arr_type == 'o':
+                return Mat5ObjectWriter(self.stream, arr, name, is_global)
+        if arr.dtype.kind in ('U', 'S'):
+            if self.unicode_strings:
+                return Mat5UniCharWriter(self.stream, arr, name, is_global)
+            else:
+                return Mat5IntCharWriter(self.stream, arr, name, is_global)            
         else:
-            return Mat5IntCharWriter(stream, arr, name, is_global)            
-    else:
-        return Mat5NumericWriter(stream, arr, name, is_global)
+            return Mat5NumericWriter(self.stream, arr, name, is_global)
                     
-def classify_mobjects(objarr):
-    ''' Function to classify objects passed for writing
-    returns
-    types         - S1 array of same shape as objarr with codes for each object
-                    i  - invalid object
-                    a  - ndarray
-                    s  - matlab struct
-                    o  - matlab object
-    arr_type       - one of
-                    c  - cell array
-                    s  - struct array
-                    o  - object array
-    '''
-    N = objarr.size
-    types = empty((N,), dtype='S1')
-    types[:] = 'i'
-    type_set = set()
-    flato = objarr.flat
-    for i in range(N):
-        obj = flato[i]
-        if isinstance(obj, ndarray):
-            types[i] = 'a'
-            continue
-        try:
-            fns = tuple(obj._fieldnames)
-        except AttributeError:
-            continue
-        try:
-            cn = obj._classname
-        except AttributeError:
-            types[i] = 's'
-            type_set.add(fns)
-            continue
-        types[i] = 'o'
-        type_set.add((cn, fns))
-    arr_type = 'c'
-    if len(set(types))==1 and len(type_set) == 1:
-        arr_type = types[0]
-    return types.reshape(objarr.shape), arr_type
-           
-        
+    def classify_mobjects(self, objarr):
+        ''' Function to classify objects passed for writing
+        returns
+        types         - S1 array of same shape as objarr with codes for each object
+                        i  - invalid object
+                        a  - ndarray
+                        s  - matlab struct
+                        o  - matlab object
+        arr_type       - one of
+                        c  - cell array
+                        s  - struct array
+                        o  - object array
+        '''
+        N = objarr.size
+        types = empty((N,), dtype='S1')
+        types[:] = 'i'
+        type_set = set()
+        flato = objarr.flat
+        for i in range(N):
+            obj = flato[i]
+            if isinstance(obj, ndarray):
+                types[i] = 'a'
+                continue
+            try:
+                fns = tuple(obj._fieldnames)
+            except AttributeError:
+                continue
+            try:
+                cn = obj._classname
+            except AttributeError:
+                types[i] = 's'
+                type_set.add(fns)
+                continue
+            types[i] = 'o'
+            type_set.add((cn, fns))
+        arr_type = 'c'
+        if len(set(types))==1 and len(type_set) == 1:
+            arr_type = types[0]
+        return types.reshape(objarr.shape), arr_type
+
+
 class MatFile5Writer(MatFileWriter):
     ''' Class for writing mat5 files '''
     def __init__(self, file_stream,
@@ -688,22 +714,32 @@
                  global_vars=None):
         super(MatFile5Writer, self).__init__(file_stream)
         self.do_compression = do_compression
-        self.unicode_strings = unicode_strings
         if global_vars:
             self.global_vars = global_vars
         else:
             self.global_vars = []
+        self.writer_getter = Mat5WriterGetter(
+            StringIO,
+            unicode_strings)
+
+    def get_unicode_strings(self):
+        return self.write_getter.unicode_strings
+    def set_unicode_strings(self, unicode_strings):
+        self.writer_getter.unicode_strings = unicode_strings
+    unicode_strings = property(get_unicode_strings,
+                               set_unicode_strings,
+                               None,
+                               'get/set unicode strings property')
         
     def put_variables(self, mdict):
         for name, var in mdict.items():
             is_global = name in self.global_vars
-            stream = StringIO()
-            matrix_writer_factory(stream,
-                                  var,
-                                  name,
-                                  is_global,
-                                  self.unicode_strings,
-                                  ).write()
+            self.writer_getter.rewind()
+            self.writer_getter.matrix_writer_factory(
+                var,
+                name,
+                is_global,
+                ).write()
             if self.do_compression:
                 str = zlib.compress(stream.getvalue())
                 tag = empty((), mdtypes_template['tag_full'])

Modified: trunk/Lib/io/miobase.py
===================================================================
--- trunk/Lib/io/miobase.py	2006-10-08 02:12:06 UTC (rev 2247)
+++ trunk/Lib/io/miobase.py	2006-10-09 13:47:02 UTC (rev 2248)
@@ -391,7 +391,27 @@
             self.dt_dict = dt_dict
         self.rtol = rtol
         self.atol = atol
-        
+
+    def eps(self, dt):
+        ''' Calculate machine precision for datatype
+
+        Machine precision defined as difference between X and smallest
+        encodable number greater than X, where X is usually 1.
+
+        Input can be datatype, in which case X=1, or X.
+        '''
+        try:
+            dt = dtype(dt)
+            start = array(1, dt)
+        except TypeError:
+            start = array(dt)
+            dt = start.dtype
+        two = array(2, dt)
+        e = start.copy()
+        while (e / two + start) > start:
+            e = e / two
+        return e
+    
     def default_dt_dict(self):
         d_dict = {}
         for sc_type in ('complex','float'):
@@ -474,10 +494,9 @@
             
     def downcast_complex(self, arr):
         # can we downcast to float?
-        flts = self.storage_criterion(arr.dtype.itemsize / 2,
-                                     ('f'),
-                                      lambda x, y: x <=y)[0]
-        test_arr = arr.astype(flt)
+        fts = self.dt_arrs['float']
+        flts = flts[flts['storage'] <= arr.dtype.itemsize]
+        test_arr = arr.astype(flt[0]['type'])
         if allclose(arr, test_arr, self.rtol, self.atol):
             return self.downcast_float(test_arr)
         # try downcasting to another complex type



More information about the Scipy-svn mailing list