[Numpy-svn] r5497 - in trunk/numpy/lib: . tests

numpy-svn@scip... numpy-svn@scip...
Tue Jul 22 01:37:59 CDT 2008


Author: charris
Date: 2008-07-22 01:37:48 -0500 (Tue, 22 Jul 2008)
New Revision: 5497

Modified:
   trunk/numpy/lib/io.py
   trunk/numpy/lib/tests/test_io.py
Log:
Apply Stefan's patch for Ryan's loadtxt fix.


Modified: trunk/numpy/lib/io.py
===================================================================
--- trunk/numpy/lib/io.py	2008-07-22 05:22:08 UTC (rev 5496)
+++ trunk/numpy/lib/io.py	2008-07-22 06:37:48 UTC (rev 5497)
@@ -10,6 +10,7 @@
 import cStringIO
 import tempfile
 import os
+import itertools
 
 from cPickle import load as _cload, loads
 from _datasource import DataSource
@@ -286,46 +287,89 @@
         raise ValueError('fname must be a string or file handle')
     X = []
 
+    def flatten_dtype(dt):
+        """Unpack a structured data-type."""
+        if dt.names is None:
+            return [dt]
+        else:
+            types = []
+            for field in dt.names:
+                tp, bytes = dt.fields[field]
+                flat_dt = flatten_dtype(tp)
+                types.extend(flat_dt)
+            return types
+
+    def split_line(line):
+        """Chop off comments, strip, and split at delimiter."""
+        line = line.split(comments)[0].strip()
+        if line:
+            return line.split(delimiter)
+        else:
+            return []
+
+    # Make sure we're dealing with a proper dtype
     dtype = np.dtype(dtype)
     defconv = _getconv(dtype)
-    converterseq = None
-    if converters is None:
-        converters = {}
-        if dtype.names is not None:
-            if usecols is None:
-                converterseq = [_getconv(dtype.fields[name][0]) \
-                                for name in dtype.names]
-            else:
-                converters.update([(col,_getconv(dtype.fields[name][0])) \
-                                    for col,name in zip(usecols, dtype.names)])
 
-    for i,line in enumerate(fh):
-        if i<skiprows: continue
-        comment_start = line.find(comments)
-        if comment_start != -1:
-            line = line[:comment_start].strip()
-        else:
-            line = line.strip()
-        if not len(line): continue
-        vals = line.split(delimiter)
-        if converterseq is None:
-            converterseq = [converters.get(j,defconv) \
-                            for j in xrange(len(vals))]
-        if usecols is not None:
-            row = [converterseq[j](vals[j]) for j in usecols]
-        else:
-            row = [converterseq[j](val) for j,val in enumerate(vals)]
-        if dtype.names is not None:
-            row = tuple(row)
-        X.append(row)
+    # Skip the first `skiprows` lines
+    for i in xrange(skiprows):
+        fh.readline()
 
-    X = np.array(X, dtype)
+    # Read until we find a line with some values, and use
+    # it to estimate the number of columns, N.
+    read_line = None
+    while not read_line:
+        first_line = fh.readline()
+        read_line = split_line(first_line)
+    N = len(usecols or read_line)
+
+    dtype_types = flatten_dtype(dtype)
+    if len(dtype_types) > 1:
+        # We're dealing with a structured array, each field of
+        # the dtype matches a column
+        converterseq = [_getconv(dt) for dt in dtype_types]
+    else:
+        # All fields have the same dtype
+        converterseq = [defconv for i in xrange(N)]
+
+    # By preference, use the converters specified by the user
+    for i, conv in (converters or {}).iteritems():
+        if usecols:
+            i = usecols.find(i)
+        converterseq[i] = conv
+
+    # Parse each line, including the first
+    for i, line in enumerate(itertools.chain([first_line], fh)):
+        vals = split_line(line)
+        if len(vals) == 0:
+            continue
+
+        if usecols:
+            vals = [vals[i] for i in usecols]
+
+        # Convert each value according to its column and store
+        X.append(tuple(conv(val) for (conv, val) in zip(converterseq, vals)))
+
+    if len(dtype_types) > 1:
+        # We're dealing with a structured array, with a dtype such as
+        # [('x', int), ('y', [('s', int), ('t', float)])]
+        #
+        # First, create the array using a flattened dtype:
+        # [('x', int), ('s', int), ('t', float)]
+        #
+        # Then, view the array using the specified dtype.
+        X = np.array(X, dtype=np.dtype([('', t) for t in dtype_types]))
+        X = X.view(dtype)
+    else:
+        X = np.array(X, dtype)
+
     X = np.squeeze(X)
-    if unpack: return X.T
-    else:  return X
+    if unpack:
+        return X.T
+    else:
+        return X
 
 
-
 def savetxt(fname, X, fmt='%.18e',delimiter=' '):
     """
     Save the data in X to file fname using fmt string to convert the

Modified: trunk/numpy/lib/tests/test_io.py
===================================================================
--- trunk/numpy/lib/tests/test_io.py	2008-07-22 05:22:08 UTC (rev 5496)
+++ trunk/numpy/lib/tests/test_io.py	2008-07-22 06:37:48 UTC (rev 5497)
@@ -232,7 +232,16 @@
         assert_equal(arr['stid'],  ["JOE",  "BOB"])
         assert_equal(arr['temp'],  [25.3,  27.9])
 
+    def test_fancy_dtype(self):
+        c = StringIO.StringIO()
+        c.write('1,2,3.0\n4,5,6.0\n')
+        c.seek(0)
+        dt = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
+        x = np.loadtxt(c, dtype=dt, delimiter=',')
+        a = np.array([(1,(2,3.0)),(4,(5,6.0))], dt)
+        assert_array_equal(x, a)
 
+
 class Testfromregex(TestCase):
     def test_record(self):
         c = StringIO.StringIO()



More information about the Numpy-svn mailing list