[Numpy-svn] r4142 - in trunk/numpy/lib: . tests

numpy-svn@scip... numpy-svn@scip...
Sun Sep 30 05:41:44 CDT 2007


Author: stefan
Date: 2007-09-30 05:41:27 -0500 (Sun, 30 Sep 2007)
New Revision: 4142

Modified:
   trunk/numpy/lib/tests/test_twodim_base.py
   trunk/numpy/lib/twodim_base.py
Log:
Fix tri returning None when dtype is bool (closes ticket #574).


Modified: trunk/numpy/lib/tests/test_twodim_base.py
===================================================================
--- trunk/numpy/lib/tests/test_twodim_base.py	2007-09-26 09:01:40 UTC (rev 4141)
+++ trunk/numpy/lib/tests/test_twodim_base.py	2007-09-30 10:41:27 UTC (rev 4142)
@@ -5,7 +5,7 @@
 from numpy.testing import *
 set_package_path()
 from numpy import arange, rot90, add, fliplr, flipud, zeros, ones, eye, \
-     array, diag, histogram2d
+     array, diag, histogram2d, tri
 import numpy as np
 restore_path()
 
@@ -160,7 +160,7 @@
         assert_array_equal(H, eye(10,10))
         assert_array_equal(xedges, np.linspace(0,9,11))
         assert_array_equal(yedges, np.linspace(0,9,11))
-        
+
     def check_asym(self):
         x = array([1, 1, 2, 3, 4, 4, 4, 5])
         y = array([1, 3, 2, 0, 1, 2, 3, 4])
@@ -187,6 +187,14 @@
         r = rand(100)+1.
         H, xed, yed = histogram2d(r, r, (4, 5), range=([0,1], [0,1]))
         assert_array_equal(H, 0)
-        
+
+class test_tri(NumpyTestCase):
+    def test_dtype(self):
+        out = array([[1,0,0],
+                     [1,1,0],
+                     [1,1,1]])
+        assert_array_equal(tri(3),out)
+        assert_array_equal(tri(3,dtype=bool),out.astype(bool))
+
 if __name__ == "__main__":
     NumpyTest().run()

Modified: trunk/numpy/lib/twodim_base.py
===================================================================
--- trunk/numpy/lib/twodim_base.py	2007-09-26 09:01:40 UTC (rev 4141)
+++ trunk/numpy/lib/twodim_base.py	2007-09-30 10:41:27 UTC (rev 4142)
@@ -108,8 +108,7 @@
     """
     if M is None: M = N
     m = greater_equal(subtract.outer(arange(N), arange(M)),-k)
-    if m.dtype != dtype:
-        return m.astype(dtype)
+    return m.astype(dtype)
 
 def tril(m, k=0):
     """ returns the elements on and below the k-th diagonal of m.  k=0 is the



More information about the Numpy-svn mailing list