[Numpy-svn] r3319 - in trunk/numpy/core: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Thu Oct 12 14:19:08 CDT 2006


Author: tim_hochberg
Date: 2006-10-12 14:19:04 -0500 (Thu, 12 Oct 2006)
New Revision: 3319

Modified:
   trunk/numpy/core/numeric.py
   trunk/numpy/core/tests/test_numeric.py
Log:
Added a docstring and tests for errstate. Also added an 'all' option to seterr so that all the options can be set at once. Note that the errstate tests are only run on Python 2.5 and higher.

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-10-12 19:16:55 UTC (rev 3318)
+++ trunk/numpy/core/numeric.py	2006-10-12 19:19:04 UTC (rev 3319)
@@ -617,11 +617,14 @@
     _errdict_rev[_errdict[key]] = key
 del key
 
-def seterr(divide=None, over=None, under=None, invalid=None):
+def seterr(all=None, divide=None, over=None, under=None, invalid=None):
     """Set how floating-point errors are handled.
 
     Valid values for each type of error are the strings
     "ignore", "warn", "raise", and "call". Returns the old settings.
+    If 'all' is specified, error types that are not otherwise specified
+    will be set to the value of 'all'; otherwise they will retain
+    their old values.
 
     Note that operations on integer scalar types (such as int16) are
     handled like floating point, and are affected by these settings.
@@ -630,19 +633,24 @@
 
     >>> seterr(over='raise')
     {'over': 'ignore', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'}
+    >>> seterr(all='warn', over='raise')
+    {'over': 'raise', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'}
     >>> int16(32000) * int16(3)
     Traceback (most recent call last):
           File "<stdin>", line 1, in ?
     FloatingPointError: overflow encountered in short_scalars
+    >>> seterr(all='ignore')
+    {'over': 'raise', 'divide': 'warn', 'invalid': 'warn', 'under': 'warn'}
+    
     """
 
     pyvals = umath.geterrobj()
     old = geterr()
 
-    if divide is None: divide = old['divide']
-    if over is None: over = old['over']
-    if under is None: under = old['under']
-    if invalid is None: invalid = old['invalid']
+    if divide is None: divide = all or old['divide']
+    if over is None: over = all or old['over']
+    if under is None: under = all or old['under']
+    if invalid is None: invalid = all or old['invalid']
 
     maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) +
                  (_errdict[over] << SHIFT_OVERFLOW ) +
@@ -653,6 +661,7 @@
     umath.seterrobj(pyvals)
     return old
 
+
 def geterr():
     """Get the current way of handling floating-point errors.
 
@@ -718,12 +727,38 @@
     return umath.geterrobj()[2]
 
 class errstate(object):
+    """with errstate(**state): --> operations in following block use given state.
+    
+    # Set error handling to known state.
+    >>> _ = seterr(invalid='raise', divide='raise', over='raise', under='ignore') 
+    
+    |>> a = -arange(3)
+    |>> with errstate(invalid='ignore'):
+    ...     print sqrt(a)
+    [ 0.     -1.#IND -1.#IND]
+    |>> print sqrt(a.astype(complex))
+    [ 0. +0.00000000e+00j  0. +1.00000000e+00j  0. +1.41421356e+00j]
+    |>> print sqrt(a)
+    Traceback (most recent call last):
+     ...
+    FloatingPointError: invalid encountered in sqrt
+    |>> with errstate(divide='ignore'):
+    ...     print a/0
+    [0 0 0]
+    |>> print a/0
+    Traceback (most recent call last):
+        ...
+    FloatingPointError: divide by zero encountered in divide
+    
+    """
+    # Note that we don't want to run the above doctests because they will fail
+    # without a from __future__ import with_statement
     def __init__(self, **kwargs):
         self.kwargs = kwargs
     def __enter__(self):
         self.oldstate = seterr(**self.kwargs)
     def __exit__(self, *exc_info):
-        numpy.seterr(**self.oldstate)
+        seterr(**self.oldstate)
 
 def _setdef():
     defval = [UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None]

Modified: trunk/numpy/core/tests/test_numeric.py
===================================================================
--- trunk/numpy/core/tests/test_numeric.py	2006-10-12 19:16:55 UTC (rev 3318)
+++ trunk/numpy/core/tests/test_numeric.py	2006-10-12 19:19:04 UTC (rev 3319)
@@ -2,6 +2,7 @@
 from numpy.random import rand, randint
 from numpy.testing import *
 from numpy.core.multiarray import dot as dot_
+import sys
 
 class Vec:
     def __init__(self,sequence=None):
@@ -246,6 +247,12 @@
 
     def test_large(self):
         assert_equal(binary_repr(10736848),'101000111101010011010000')
+        
+import sys
+if sys.version_info[:2] >= (2, 5):
+    set_local_path()
+    from test_errstate import *
+    restore_path()
 
 if __name__ == '__main__':
     NumpyTest().run()



More information about the Numpy-svn mailing list