[Numpy-svn] r5042 - in trunk/numpy/testing: . tests

numpy-svn@scip... numpy-svn@scip...
Thu Apr 17 15:51:03 CDT 2008


Author: rkern
Date: 2008-04-17 15:51:03 -0500 (Thu, 17 Apr 2008)
New Revision: 5042

Modified:
   trunk/numpy/testing/tests/test_utils.py
   trunk/numpy/testing/utils.py
Log:
Correct dependency on missing code.

Modified: trunk/numpy/testing/tests/test_utils.py
===================================================================
--- trunk/numpy/testing/tests/test_utils.py	2008-04-17 20:23:30 UTC (rev 5041)
+++ trunk/numpy/testing/tests/test_utils.py	2008-04-17 20:51:03 UTC (rev 5042)
@@ -1,10 +1,10 @@
 import numpy as N
 from numpy.testing.utils import *
 
-class _GenericTest:
-    def __init__(self, assert_func):
-        self._assert_func = assert_func
+import unittest
 
+
+class _GenericTest(object):
     def _test_equal(self, a, b):
         self._assert_func(a, b)
 
@@ -47,9 +47,9 @@
 
         self._test_not_equal(a, b)
 
-class TestEqual(_GenericTest):
-    def __init__(self):
-        _GenericTest.__init__(self, assert_array_equal)
+class TestEqual(_GenericTest, unittest.TestCase):
+    def setUp(self):
+        self._assert_func = assert_array_equal
 
     def test_generic_rank1(self):
         """Test rank 1 array for all dtypes."""
@@ -126,6 +126,42 @@
         self._test_not_equal(c, b)
 
 
-class TestAlmostEqual(_GenericTest):
-    def __init__(self):
-        _GenericTest.__init__(self, assert_array_almost_equal)
+class TestAlmostEqual(_GenericTest, unittest.TestCase):
+    def setUp(self):
+        self._assert_func = assert_array_almost_equal
+
+
+class TestRaises(unittest.TestCase):
+    def setUp(self):
+        class MyException(Exception):
+            pass
+
+        self.e = MyException
+
+    def raises_exception(self, e):
+        raise e
+
+    def does_not_raise_exception(self):
+        pass
+
+    def test_correct_catch(self):
+        f = raises(self.e)(self.raises_exception)(self.e)
+
+    def test_wrong_exception(self):
+        try:
+            f = raises(self.e)(self.raises_exception)(RuntimeError)
+        except RuntimeError:
+            return
+        else:
+            raise AssertionError("should have caught RuntimeError")
+
+    def test_catch_no_raise(self):
+        try:
+            f = raises(self.e)(self.does_not_raise_exception)()
+        except AssertionError:
+            return
+        else:
+            raise AssertionError("should have raised an AssertionError")
+
+if __name__ == '__main__':
+    unittest.main()

Modified: trunk/numpy/testing/utils.py
===================================================================
--- trunk/numpy/testing/utils.py	2008-04-17 20:23:30 UTC (rev 5041)
+++ trunk/numpy/testing/utils.py	2008-04-17 20:51:03 UTC (rev 5042)
@@ -294,36 +294,33 @@
     msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
     assert actual==desired, msg
 
-# Ripped from nose.tools
-def raises(*exceptions):
-    """Test must raise one of expected exceptions to pass.
 
-    Example use::
-
-      @raises(TypeError, ValueError)
-      def test_raises_type_error():
-          raise TypeError("This test passes")
-
-      @raises(Exception):
-      def test_that_fails_by_passing():
-          pass
-
-    If you want to test many assertions about exceptions in a single test,
-    you may want to use `assert_raises` instead.
+def raises(*exceptions):
+    """ Assert that a test function raises one of the specified exceptions to
+    pass.
     """
-    valid = ' or '.join([e.__name__ for e in exceptions])
-    def decorate(func):
-        name = func.__name__
-        def newfunc(*arg, **kw):
+    # FIXME: when we transition to nose, just use its implementation. It's
+    # better.
+    def deco(function):
+        def f2(*args, **kwds):
             try:
-                func(*arg, **kw)
+                function(*args, **kwds)
             except exceptions:
                 pass
             except:
+                # Anything else.
                 raise
             else:
-                message = "%s() did not raise %s" % (name, valid)
-                raise AssertionError(message)
-        newfunc = make_decorator(func)(newfunc)
-        return newfunc
-    return decorate
+                raise AssertionError('%s() did not raise one of (%s)' % 
+                    (function.__name__, ', '.join([e.__name__ for e in exceptions])))
+        try:
+            f2.__name__ = function.__name__
+        except TypeError:
+            # Python 2.3 does not permit this.
+            pass
+        f2.__dict__ = function.__dict__
+        f2.__doc__ = function.__doc__
+        f2.__module__ = function.__module__
+        return f2
+
+    return deco



More information about the Numpy-svn mailing list