[Numpy-svn] r5072 - in trunk/numpy/core: . tests

numpy-svn@scip... numpy-svn@scip...
Wed Apr 23 15:32:14 CDT 2008


Author: stefan
Date: 2008-04-23 15:32:02 -0500 (Wed, 23 Apr 2008)
New Revision: 5072

Modified:
   trunk/numpy/core/defmatrix.py
   trunk/numpy/core/tests/test_defmatrix.py
Log:
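Hack to make x[0][0] return a scalar for matrices: indexing a row or column vector matrix with a single scalar index now returns the corresponding element.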


Modified: trunk/numpy/core/defmatrix.py
===================================================================
--- trunk/numpy/core/defmatrix.py	2008-04-23 14:46:29 UTC (rev 5071)
+++ trunk/numpy/core/defmatrix.py	2008-04-23 20:32:02 UTC (rev 5072)
@@ -224,6 +224,14 @@
 
     def __getitem__(self, index):
         self._getitem = True
+
+        # If indexing by scalar, check whether we are indexing into
+        # a vector, and then return the corresponding element
+        if N.isscalar(index) and (1 in self.shape):
+            index = [index,index]
+            index[list(self.shape).index(1)] = 0
+            index = tuple(index)
+
         try:
             out = N.ndarray.__getitem__(self, index)
         finally:
@@ -254,7 +262,6 @@
             if (val > 1): truend += 1
         return truend
 
-
     def __mul__(self, other):
         if isinstance(other,(N.ndarray, list, tuple)) :
             # This promotes 1-D vectors to row vectors

Modified: trunk/numpy/core/tests/test_defmatrix.py
===================================================================
--- trunk/numpy/core/tests/test_defmatrix.py	2008-04-23 14:46:29 UTC (rev 5071)
+++ trunk/numpy/core/tests/test_defmatrix.py	2008-04-23 20:32:02 UTC (rev 5072)
@@ -179,6 +179,17 @@
         x[:,1] = y>0.5
         assert_equal(x, [[0,1],[0,0],[0,0]])
 
+    def check_vector_element(self):
+        x = matrix([[1,2,3],[4,5,6]])
+        assert_equal(x[0][0],1)
+        assert_equal(x[0].shape,(1,3))
+        assert_equal(x[:,0].shape,(2,1))
 
+        x = matrix(0)
+        assert_equal(x[0,0],0)
+        assert_equal(x[0],0)
+        assert_equal(x[:,0].shape,x.shape)
+
+
 if __name__ == "__main__":
     NumpyTest().run()



More information about the Numpy-svn mailing list