[Numpy-discussion] use index array of len n to select columns of n x m array

Martin Spacek numpy@mspacek.mm...
Fri Aug 6 05:01:10 CDT 2010


Keith Goodman wrote:
 > Here's one way:
 >
 >>> a.flat[i + a.shape[1] * np.arange(a.shape[0])]
 >     array([0, 3, 5, 6, 9])
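The quoted one-liner relies on row-major (C-order) flat indexing. Row r of an n x m array starts at flat offset r*m, so element (r, i[r]) lives at flat offset i[r] + r*m. Here is a minimal sketch of that arithmetic, using a hypothetical 1D index `i1` on the 5x4 array from the example below. It checks the offsets against `np.ravel_multi_index` and the values against plain integer-array indexing:

```python
import numpy as np

a = np.arange(20).reshape(5, 4)
i1 = np.array([2, 3, 1, 0, 3])  # hypothetical: one column index per row

rows = np.arange(a.shape[0])
# Flat offset of element (r, i1[r]) in row-major order
offsets = i1 + a.shape[1] * rows  # offsets is [2, 7, 9, 12, 19]
values = a.flat[offsets]          # equal to the offsets here, since a is arange

# The same offsets via NumPy's index-conversion helper
assert np.array_equal(offsets, np.ravel_multi_index((rows, i1), a.shape))
# ...and the same values via integer-array indexing
assert np.array_equal(values, a[rows, i1])
```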


I'm afraid I made my example a little too simple. In retrospect, what I really
want is to use a 2D index array "i", where each row of i holds the column
indices to take from the corresponding row of a, like this:

 >>> a = np.array([[ 0,  1,  2,  3],
 ...               [ 4,  5,  6,  7],
 ...               [ 8,  9, 10, 11],
 ...               [12, 13, 14, 15],
 ...               [16, 17, 18, 19]])
 >>> i = np.array([[2, 1],
 ...               [3, 1],
 ...               [1, 1],
 ...               [0, 0],
 ...               [3, 1]])
 >>> foo(a, i)
 array([[ 2,  1],
        [ 7,  5],
        [ 9,  9],
        [12, 12],
        [19, 17]])
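To pin down what `foo` should compute: for every row r and every position c, the output is `a[r, i[r, c]]`. Below is a plain-Python reference sketch. It is slow, but handy for checking faster versions. The name `foo_reference` is hypothetical.

```python
import numpy as np

def foo_reference(a, i):
    """Hypothetical reference version: out[r, c] = a[r, i[r, c]]."""
    out = np.empty(i.shape, dtype=a.dtype)
    for r in range(i.shape[0]):      # each row of a (and of i)
        for c in range(i.shape[1]):  # each column index requested for that row
            out[r, c] = a[r, i[r, c]]
    return out

a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]])
print(foo_reference(a, i))  # reproduces the expected output in the example
```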

I think the flat iterator indexing suggestion is about the only thing that'll 
work. Here's the function I've pretty much settled on:

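import numpy as np

def rowtake(a, i):
    """For each row in a, return values according to column indices in the
    corresponding row in i. Returned shape == i.shape"""
    assert a.ndim == 2
    assert i.ndim <= 2
    # Flat (C-order) offset at which each row of a starts. Note that
    # out-of-range or negative entries in i silently spill into other rows.
    rowstarts = a.shape[1] * np.arange(a.shape[0])
    if i.ndim == 1:
        return a.flat[i + rowstarts]
    else: # i.ndim == 2: use a column of row starts so it broadcasts over i
        return a.flat[i + rowstarts[:, np.newaxis]]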
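For comparison, NumPy's integer-array ("fancy") indexing can express the same selection without flat offsets. A column of row numbers broadcasts against `i`, so `a[rows[:, np.newaxis], i]` picks `a[r, i[r, c]]`. NumPy 1.15 and later also provide `np.take_along_axis` for this pattern. This is a sketch under those assumptions, checked against the example. Both are alternative techniques, not the flat-iterator approach used in rowtake.

```python
import numpy as np

a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]])

rows = np.arange(a.shape[0])
# Broadcasting: rows[:, np.newaxis] has shape (5, 1) and i has shape (5, 2),
# so the result has shape (5, 2) with out[r, c] = a[r, i[r, c]]
out_fancy = a[rows[:, np.newaxis], i]

# Same result via take_along_axis (requires NumPy >= 1.15)
out_tal = np.take_along_axis(a, i, axis=1)

expected = np.array([[2, 1], [7, 5], [9, 9], [12, 12], [19, 17]])
assert np.array_equal(out_fancy, expected)
assert np.array_equal(out_tal, expected)

# 1D case: one column per row, so no extra axis is needed
i1 = np.array([2, 3, 1, 0, 3])
assert np.array_equal(a[rows, i1], [2, 7, 9, 12, 19])
```

Unlike flat offsets, fancy indexing raises IndexError for a column index that is too large. It also interprets a negative index within its own row, so -1 picks the last column of row r rather than the last element of row r-1.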

This runs about half as fast as my Cython function, but the Cython function only
works for a fixed dtype (int32) and ndim (2):

# cython: language_level=3
import numpy as np
cimport numpy as np
cimport cython

np.import_array()

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def rowtake_cy(np.ndarray[np.int32_t, ndim=2] a,
               np.ndarray[np.int32_t, ndim=2] i):
    """For each row in a, return values according to column indices in the
    corresponding row in i. Returned shape == i.shape"""

    cdef Py_ssize_t nrows, ncols, rowi, coli
    cdef np.ndarray[np.int32_t, ndim=2] out

    nrows = i.shape[0]
    ncols = i.shape[1] # num cols to take from a for each row
    assert a.shape[0] == nrows
    # Bounds checking and wraparound are disabled, so validate indices up
    # front: a negative index would read outside a's buffer
    assert i.min() >= 0 and i.max() < a.shape[1]
    out = np.empty((nrows, ncols), dtype=np.int32)

    for rowi in range(nrows):
        for coli in range(ncols):
            out[rowi, coli] = a[rowi, i[rowi, coli]]

    return out
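A minimal `timeit` harness sketch for running a speed comparison like the one above on your own machine. The array sizes are arbitrary assumptions, `np.random.default_rng` needs NumPy 1.17 or later, and no timings are claimed here. It compares the flat-offset approach with broadcast fancy indexing. The Cython version would have to be compiled separately and added to the list.

```python
import timeit
import numpy as np

def rowtake_flat(a, i):
    # Flat-offset approach for 2D i (same idea as rowtake)
    return a.flat[i + a.shape[1] * np.arange(a.shape[0])[:, np.newaxis]]

def rowtake_fancy(a, i):
    # Broadcast integer-array indexing
    return a[np.arange(a.shape[0])[:, np.newaxis], i]

rng = np.random.default_rng(0)
nrows, ncols, ntake = 100_000, 50, 10  # arbitrary sizes for illustration
a = rng.integers(0, 1000, size=(nrows, ncols), dtype=np.int32)
i = rng.integers(0, ncols, size=(nrows, ntake), dtype=np.int32)

# Both approaches must agree before timing them
assert np.array_equal(rowtake_flat(a, i), rowtake_fancy(a, i))

for name, fn in [("flat", rowtake_flat), ("fancy", rowtake_fancy)]:
    # Best of 3 repeats, 10 calls each, reported per call
    t = min(timeit.repeat(lambda: fn(a, i), number=10, repeat=3))
    print(f"{name}: {t / 10 * 1e3:.2f} ms per call")
```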

Cheers,

Martin

