[Numpy-svn] r6120 - trunk/numpy/lib/tests

numpy-svn@scip... numpy-svn@scip...
Sat Nov 29 06:08:06 CST 2008


Author: stefan
Date: 2008-11-29 06:07:54 -0600 (Sat, 29 Nov 2008)
New Revision: 6120

Modified:
   trunk/numpy/lib/tests/test_io.py
Log:
Add test for load's mmap_mode.

Modified: trunk/numpy/lib/tests/test_io.py
===================================================================
--- trunk/numpy/lib/tests/test_io.py	2008-11-29 12:07:07 UTC (rev 6119)
+++ trunk/numpy/lib/tests/test_io.py	2008-11-29 12:07:54 UTC (rev 6120)
@@ -2,51 +2,114 @@
 import numpy as np
 import StringIO

+import os
+from tempfile import mkstemp

 class RoundtripTest:
+    def roundtrip(self, save_func, *args, **kwargs):
+        """Save `args` with `save_func`, then reload them with `numpy.load`.
+
+        save_func : callable
+            Function used to save arrays to file; called as
+            ``save_func(file, *args, **save_kwds)``.
+        args : tuple of arrays
+            Arrays stored to file.
+
+        The remaining options are keyword arguments:
+
+        file_on_disk : bool
+            If true, store the file on disk, instead of in a
+            string buffer.  Default: False.
+        save_kwds : dict
+            Parameters passed to `save_func`.  Default: {}.
+        load_kwds : dict
+            Parameters passed to `numpy.load`.  Default: {}.
+
+        On return, ``self.arr`` holds `args` and ``self.arr_reloaded``
+        holds the object returned by `numpy.load`.
+
+        """
+        save_kwds = kwargs.get('save_kwds', {})
+        load_kwds = kwargs.get('load_kwds', {})
+        file_on_disk = kwargs.get('file_on_disk', False)
+
+        arr = args
+
+        if file_on_disk:
+            # Write through a mkstemp handle and close it before loading:
+            # a NamedTemporaryFile cannot be reopened by name while it is
+            # still open on Windows.  tearDown removes the file.
+            fd, load_file = mkstemp()
+            if not hasattr(self, '_tmp_names'):
+                self._tmp_names = []
+            self._tmp_names.append(load_file)
+            target_file = os.fdopen(fd, 'wb')
+            try:
+                save_func(target_file, *arr, **save_kwds)
+            finally:
+                target_file.close()
+        else:
+            target_file = StringIO.StringIO()
+            save_func(target_file, *arr, **save_kwds)
+            target_file.seek(0)
+            load_file = target_file
+
+        arr_reloaded = np.load(load_file, **load_kwds)
+
+        self.arr = arr
+        self.arr_reloaded = arr_reloaded
+
+    def tearDown(self):
+        """Remove the temporary files written by `roundtrip`."""
+        # Drop the reloaded data first: a memmap or NpzFile may still
+        # hold the file open, and Windows refuses to delete open files.
+        self.arr_reloaded = None
+        for name in getattr(self, '_tmp_names', []):
+            os.remove(name)
+        self._tmp_names = []
+
     def test_array(self):
-        a = np.array( [[1,2],[3,4]], float)
-        self.do(a)
+        a = np.array([[1, 2], [3, 4]], float)
+        self.roundtrip(a)

-        a = np.array( [[1,2],[3,4]], int)
-        self.do(a)
+        a = np.array([[1, 2], [3, 4]], int)
+        self.roundtrip(a)

-        a = np.array( [[1+5j,2+6j],[3+7j,4+8j]], dtype=np.csingle)
-        self.do(a)
+        a = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.csingle)
+        self.roundtrip(a)

-        a = np.array( [[1+5j,2+6j],[3+7j,4+8j]], dtype=np.cdouble)
-        self.do(a)
+        a = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.cdouble)
+        self.roundtrip(a)

     def test_1D(self):
-        a = np.array([1,2,3,4], int)
-        self.do(a)
+        a = np.array([1, 2, 3, 4], int)
+        self.roundtrip(a)

+    def test_mmap(self):
+        a = np.array([[1, 2.5], [4, 7.3]])
+        self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'})
+
     def test_record(self):
         a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
-        self.do(a)
+        self.roundtrip(a)

 class TestSaveLoad(RoundtripTest, TestCase):
-    def do(self, a):
-        c = StringIO.StringIO()
-        np.save(c, a)
-        c.seek(0)
-        a_reloaded = np.load(c)
-        assert_equal(a, a_reloaded)
+    def roundtrip(self, *args, **kwargs):
+        RoundtripTest.roundtrip(self, np.save, *args, **kwargs)
+        # np.save stores a single array: compare with the only input.
+        assert_equal(self.arr[0], self.arr_reloaded)

-
 class TestSavezLoad(RoundtripTest, TestCase):
-    def do(self, *arrays):
-        c = StringIO.StringIO()
-        np.savez(c, *arrays)
-        c.seek(0)
-        l = np.load(c)
-        for n, a in enumerate(arrays):
-            assert_equal(a, l['arr_%d' % n])
+    def roundtrip(self, *args, **kwargs):
+        RoundtripTest.roundtrip(self, np.savez, *args, **kwargs)
+        # np.savez names positional arrays arr_0, arr_1, ... in order.
+        for n, arr in enumerate(self.arr):
+            assert_equal(arr, self.arr_reloaded['arr_%d' % n])

     def test_multiple_arrays(self):
         a = np.array( [[1,2],[3,4]], float)
         b = np.array( [[1+2j,2+7j],[3-6j,4+12j]], complex)
-        self.do(a,b)
+        self.roundtrip(a,b)

     def test_named_arrays(self):
         a = np.array( [[1,2],[3,4]], float)



More information about the Numpy-svn mailing list