[Scipy-svn] r3938 - in trunk/scipy/sparse: . tests

scipy-svn@scip... scipy-svn@scip...
Thu Feb 14 02:35:00 CST 2008


Author: wnbell
Date: 2008-02-14 02:34:57 -0600 (Thu, 14 Feb 2008)
New Revision: 3938

Modified:
   trunk/scipy/sparse/csr.py
   trunk/scipy/sparse/tests/test_base.py
Log:
added better CSR slicing (fancy indexing via csr_matrix.__getitem__: integer, slice and index-list keys)


Modified: trunk/scipy/sparse/csr.py
===================================================================
--- trunk/scipy/sparse/csr.py	2008-02-14 06:00:10 UTC (rev 3937)
+++ trunk/scipy/sparse/csr.py	2008-02-14 08:34:57 UTC (rev 3938)
@@ -8,10 +8,11 @@
 import numpy
 from numpy import array, matrix, asarray, asmatrix, zeros, rank, intc, \
         empty, hstack, isscalar, ndarray, shape, searchsorted, where, \
-        concatenate, deprecate
+        concatenate, deprecate, arange, ones
 
 from base import spmatrix, isspmatrix
-from sparsetools import csr_tocsc, csr_tobsr, csr_count_blocks
+from sparsetools import csr_tocsc, csr_tobsr, csr_count_blocks, \
+        get_csr_submatrix
 from sputils import upcast, to_native, isdense, isshape, getdtype, \
         isscalarlike, isintlike
 
@@ -182,6 +183,189 @@
         return (x[0],x[1])
 
 
+    def __getitem__(self, key):
+        def asindices(x):
+            try:
+                x = asarray(x,dtype='intc')
+            except:
+                raise IndexError('invalid index')
+            else:
+                return x
+
+        def extractor(indices,N):
+            """Return a sparse matrix P so that P*self implements
+            slicing of the form self[[1,2,3],:]
+            """
+            indices = asindices(indices)
+            
+            max_indx = indices.max()
+
+            if max_indx >  N:
+                raise ValueError('index (%d) out of range' % max_indx)
+
+            min_indx = indices.min()
+            if min_indx < -N:
+                raise ValueError('index (%d) out of range' % (N + min_indx))
+            
+            if min_indx < 0:
+                indices = indices.copy()
+                indices[indices < 0] += N
+
+            indptr  = arange(len(indices) + 1, dtype='intc')
+            data    = ones(len(indices), dtype=self.dtype)
+            shape   = (len(indices),N)
+
+            return csr_matrix( (data,indices,indptr), shape=shape)
+            
+
+        if isinstance(key, tuple):
+            row = key[0]
+            col = key[1]
+           
+            #TODO implement CSR[ [1,2,3], X ] with sparse matmat
+            #TODO make use of sorted indices
+
+            if isintlike(row):
+                #[1,??]
+                if isintlike(col):
+                    return self._get_single_element(row, col) #[i,j]
+                elif isinstance(col, slice):
+                    return self._get_row_slice(row, col)      #[i,1:2]
+                else:
+                    P = extractor(col,self.shape[1]).T        #[i,[1,2]]
+                    return self[row,:]*P
+                    
+            elif isinstance(row, slice):
+                #[1:2,??]
+                if isintlike(col) or isinstance(col, slice):
+                    return self._get_submatrix(row, col)      #[1:2,j]
+                else:
+                    P = extractor(col,self.shape[1]).T        #[1:2,[1,2]]
+                    return self[row,:]*P
+            else:    
+                #[[1,2],??]
+                if isintlike(col) or isinstance(col,slice):
+                    P = extractor(row, self.shape[0])         
+                    return (P*self)[:,col]                   #[[1,2],j] or [[1,2],1:2] 
+                else:
+                    row = asindices(row)                     #[[1,2],[1,2]]
+                    col = asindices(col)
+                    if len(row) != len(col):
+                        raise ValueError('number of row and column indices differ')
+                    val = []
+                    for i,j in zip(row,col):
+                        val.append(self._get_single_element(i,j))
+                    return asmatrix(val)
+
+
+        elif isintlike(key) or isinstance(key,slice):
+            return self[key,:]                                #[i] or [1:2]
+        else:
+            return self[asindices(key),:]                     #[[1,2]]
+    
+
+    def _get_single_element(self,row,col):
+        M, N = self.shape
+        if (row < 0):
+            row += M
+        if (col < 0):
+            col += N
+        if not (0<=row<M) or not (0<=col<N):
+            raise IndexError, "index out of bounds"
+        
+        start = self.indptr[row]
+        end   = self.indptr[row+1]
+        indxs = where(col == self.indices[start:end])[0]
+
+        num_matches = len(indxs)
+
+        if num_matches == 0:
+            # entry does not appear in the matrix
+            return self.dtype.type(0)
+        elif num_matches == 1:
+            return self.data[start:end][indxs[0]]
+        else:
+            raise ValueError('nonzero entry (%d,%d) occurs more than once' % (row,col) )
+
+    def _get_row_slice(self, i, cslice ):
+        """Returns a copy of self[i, cslice] 
+        """
+        if i < 0:
+            i += self.shape[0]
+
+        if i < 0:
+            raise ValueError('index (%d) out of range' % i ) 
+
+        start, stop, stride = cslice.indices(self.shape[1])
+
+        if stride != 1:
+            raise ValueError, "slicing with step != 1 not supported"
+        if stop <= start:
+            raise ValueError, "slice width must be >= 1"
+
+        #TODO make [i,:] faster
+        #TODO implement [i,x:y:z]
+
+        indices = []
+
+        for ind in xrange(self.indptr[i], self.indptr[i+1]):
+            if self.indices[ind] >= start and self.indices[ind] < stop:
+                indices.append(ind)
+
+        index  = self.indices[indices] - start
+        data   = self.data[indices]
+        indptr = numpy.array([0, len(indices)])
+        return csr_matrix( (data, index, indptr), shape=(1, stop-start) )
+
+    def _get_submatrix( self, row_slice, col_slice ):
+        """Return a submatrix of this matrix (new matrix is created)."""
+
+        M,N = self.shape
+
+        def process_slice( sl, num ):
+            if isinstance( sl, slice ):
+                i0, i1 = sl.start, sl.stop
+                if i0 is None:
+                    i0 = 0
+                elif i0 < 0:
+                    i0 = num + i0
+
+                if i1 is None:
+                    i1 = num
+                elif i1 < 0:
+                    i1 = num + i1
+
+                return i0, i1
+
+            elif isscalar( sl ):
+                if sl < 0:
+                    sl += num
+
+                return sl, sl + 1
+
+            else:
+                raise TypeError('expected slice or scalar')
+
+        def check_bounds( i0, i1, num ):
+            if not (0<=i0<num) or not (0<i1<=num) or not (i0<i1):
+                raise IndexError,\
+                      "index out of bounds: 0<=%d<%d, 0<=%d<%d, %d<%d" %\
+                      (i0, num, i1, num, i0, i1)
+
+        i0, i1 = process_slice( row_slice, M )
+        j0, j1 = process_slice( col_slice, N )
+        check_bounds( i0, i1, M )
+        check_bounds( j0, j1, N )
+
+        indptr, indices, data = get_csr_submatrix( M, N, \
+                self.indptr, self.indices, self.data, i0, i1, j0, j1 )
+
+        shape =  (i1 - i0, j1 - j0)
+
+        return self.__class__( (data,indices,indptr), shape=shape )
+
+
+
 from sputils import _isinstance
 
 def isspmatrix_csr(x):

Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py	2008-02-14 06:00:10 UTC (rev 3937)
+++ trunk/scipy/sparse/tests/test_base.py	2008-02-14 08:34:57 UTC (rev 3938)
@@ -628,9 +628,9 @@
     def test_get_slices(self):
         B = asmatrix(arange(50.).reshape(5,10))
         A = self.spmatrix(B)
-        assert_array_equal(B[2:5,0:3], A[2:5,0:3].todense())
-        assert_array_equal(B[1:,:-1], A[1:,:-1].todense())
-        assert_array_equal(B[:-1,1:], A[:-1,1:].todense())
+        assert_array_equal(A[2:5,0:3].todense(), B[2:5,0:3])
+        assert_array_equal(A[1:,:-1].todense(),  B[1:,:-1]) 
+        assert_array_equal(A[:-1,1:].todense(),  B[:-1,1:])
 
         # Now test slicing when a column contains only zeros
         E = matrix([[1, 0, 1], [4, 0, 0], [0, 0, 0], [0, 0, 1]])
@@ -852,6 +852,66 @@
         assert_array_equal(asp.todense(),bsp.todense())
 
 
+    def test_fancy_slicing(self):
+        #TODO add this to csc_matrix
+        B = asmatrix(arange(50).reshape(5,10))
+        A = csr_matrix( B )
+
+        # [i,j]
+        assert_equal(A[2,3],B[2,3])
+        assert_equal(A[-1,8],B[-1,8])
+        assert_equal(A[-1,-2],B[-1,-2])
+
+        # [i,1:2]
+        assert_equal(A[2,:].todense(),B[2,:])
+        assert_equal(A[2,5:-2].todense(),B[2,5:-2])
+       
+        # [i,[1,2]]
+        assert_equal(A[3,[1,3]].todense(),B[3,[1,3]])
+        assert_equal(A[-1,[2,-5]].todense(),B[-1,[2,-5]])
+
+        # [1:2,j]
+        assert_equal(A[:,2].todense(),B[:,2])
+        assert_equal(A[3:4,9].todense(),B[3:4,9])
+        assert_equal(A[1:4,-5].todense(),B[1:4,-5])
+
+        # [1:2,[1,2]]
+        assert_equal(A[:,[2,8,3,-1]].todense(),B[:,[2,8,3,-1]])
+        assert_equal(A[3:4,[9]].todense(),B[3:4,[9]])
+        assert_equal(A[1:4,[-1,-5]].todense(),B[1:4,[-1,-5]])
+
+        # [[1,2],j]
+        assert_equal(A[[1,3],3].todense(),B[[1,3],3])
+        assert_equal(A[[2,-5],-4].todense(),B[[2,-5],-4])
+        
+        # [[1,2],1:2]
+        assert_equal(A[[1,3],:].todense(),B[[1,3],:])
+        assert_equal(A[[2,-5],8:-1].todense(),B[[2,-5],8:-1])
+    
+        # [[1,2],[1,2]]
+        assert_equal(A[[1,3],[2,4]],B[[1,3],[2,4]])
+        assert_equal(A[[-1,-3],[2,-4]],B[[-1,-3],[2,-4]])
+
+        # [i]
+        assert_equal(A[1].todense(),B[1])
+        assert_equal(A[-2].todense(),B[-2])
+
+        # [1:2]
+        assert_equal(A[1:4].todense(),B[1:4])
+        assert_equal(A[1:-2].todense(),B[1:-2])
+
+        # [[1,2]]
+        assert_equal(A[[1,3]].todense(),B[[1,3]])
+        assert_equal(A[[-1,-3]].todense(),B[[-1,-3]])
+
+        # [[1,2],:][:,[1,2]]
+        assert_equal(A[[1,3],:][:,[2,4]].todense(),    B[[1,3],:][:,[2,4]]    )
+        assert_equal(A[[-1,-3],:][:,[2,-4]].todense(), B[[-1,-3],:][:,[2,-4]] )
+
+        # [:,[1,2]][[1,2],:]
+        assert_equal(A[:,[1,3]][[2,4],:].todense(),    B[:,[1,3]][[2,4],:]    )
+        assert_equal(A[:,[-1,-3]][[2,-4],:].todense(), B[:,[-1,-3]][[2,-4],:] )
+
 class TestCSC(_TestCommon, _TestGetSet, _TestSolve,
         _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
         _TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,



More information about the Scipy-svn mailing list