[Numpy-discussion] use index array of len n to select columns of n x m array

josef.pktd@gmai...
Fri Aug 6 05:29:25 CDT 2010


On Fri, Aug 6, 2010 at 6:01 AM, Martin Spacek <numpy@mspacek.mm.st> wrote:
> Keith Goodman wrote:
>  > Here's one way:
>  >
>  >>> a.flat[i + a.shape[1] * np.arange(a.shape[0])]
>  >     array([0, 3, 5, 6, 9])
>
>
> I'm afraid I made my example a little too simple. In retrospect, what I really
> want is to be able to use a 2D index array "i", like this:
>
>  >>> a = np.array([[ 0,  1,  2,  3],
>                   [ 4,  5,  6,  7],
>                   [ 8,  9, 10, 11],
>                   [12, 13, 14, 15],
>                   [16, 17, 18, 19]])
>  >>> i = np.array([[2, 1],
>                   [3, 1],
>                   [1, 1],
>                   [0, 0],
>                   [3, 1]])
>  >>> foo(a, i)
> array([[ 2,  1],
>        [ 7,  5],
>        [ 9,  9],
>        [12, 12],
>        [19, 17]])
>
> I think the flat iterator indexing suggestion is about the only thing that'll
> work. Here's the function I've pretty much settled on:
>
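> import numpy as np
>
> def rowtake(a, i):
>     """For each row in a, return values according to column indices in the
>     corresponding row in i. Returned shape == i.shape"""
>     assert a.ndim == 2
>     assert i.ndim <= 2
>     # a.flat walks a in row-major order, so a[r, c] sits at flat index r*ncols + c
>     if i.ndim == 1:
>         return a.flat[i + a.shape[1] * np.arange(a.shape[0])]
>     else: # i.ndim == 2
>         # np.vstack turns the row numbers into an (nrows, 1) column that broadcasts over i
>         return a.flat[i + a.shape[1] * np.vstack(np.arange(a.shape[0]))]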
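The arithmetic behind `rowtake`: `a.flat` walks `a` in row-major (C) order whatever its memory layout, so `a[r, c]` sits at flat position `r * a.shape[1] + c`. A minimal sketch that computes those positions for the arrays in this thread; `flat_positions` is a hypothetical helper, not something from the thread:

```python
import numpy as np

# The example arrays from this thread; a is np.arange(20).reshape(5, 4).
a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1],
              [3, 1],
              [1, 1],
              [0, 0],
              [3, 1]])

def flat_positions(a, i):
    """Hypothetical helper: flat (row-major) position of a[r, i[r, k]]."""
    nrows, ncols = a.shape
    rows = np.arange(nrows)[:, None]  # (nrows, 1), same shape as np.vstack(np.arange(nrows))
    return rows * ncols + i           # broadcasts to i.shape

pos = flat_positions(a, i)
print(pos)          # flat indices into a
print(a.flat[pos])  # indexing .flat with a 2-D index array keeps the index's shape
```

Because `a` here is `np.arange(20)` reshaped, every value equals its own flat position, so `pos` and `a.flat[pos]` print the same numbers, which makes the result easy to eyeball.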


I still find broadcasting easier to read, even if it might be a bit slower:

>>> a[np.arange(5)[:,None], i]
array([[ 2,  1],
       [ 7,  5],
       [ 9,  9],
       [12, 12],
       [19, 17]])
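The `np.arange(5)` above is specific to the 5-row example. A sketch of the same broadcasting idea for any number of rows, assuming `i` is 1-D or 2-D with one row per row of `a` (the names `rowtake_bcast` and `rowtake_along` are hypothetical). NumPy 1.15 and later also provide `np.take_along_axis`, which does this row-wise take directly; it requires the index array to have as many dimensions as `a`, hence the extra column axis in the 1-D case.

```python
import numpy as np

def rowtake_bcast(a, i):
    """Hypothetical name: out[r, k] = a[r, i[r, k]] via broadcast fancy indexing."""
    a = np.asarray(a)
    i = np.asarray(i)
    assert a.ndim == 2 and i.ndim in (1, 2) and i.shape[0] == a.shape[0]
    rows = np.arange(a.shape[0])
    if i.ndim == 2:
        rows = rows[:, None]  # (nrows, 1) broadcasts against (nrows, k)
    return a[rows, i]         # result has i.shape

def rowtake_along(a, i):
    """Same result with np.take_along_axis (NumPy >= 1.15)."""
    i = np.asarray(i)
    if i.ndim == 1:
        # take_along_axis needs index.ndim == a.ndim: add a column axis, then drop it.
        return np.take_along_axis(a, i[:, None], axis=1)[:, 0]
    return np.take_along_axis(a, i, axis=1)

a = np.arange(20).reshape(5, 4)
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]])
print(rowtake_bcast(a, i))
print(rowtake_along(a, i))
```

One behavioural difference from the flat-index version: fancy indexing bounds-checks every column index (and treats negatives as counting from the end of the row), whereas the flat arithmetic silently reads from a neighbouring row when an index is >= `a.shape[1]`.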

Josef


>
> This is about half as fast as my Cython function, but the Cython function is
> limited to fixed dtypes and ndim:
>
> cimport cython
> cimport numpy as np
> import numpy as np
>
> @cython.boundscheck(False)
> @cython.wraparound(False)
> @cython.cdivision(True)
> def rowtake_cy(np.ndarray[np.int32_t, ndim=2] a,
>                np.ndarray[np.int32_t, ndim=2] i):
>     """For each row in a, return values according to column indices in the
>     corresponding row in i. Returned shape == i.shape"""
>
>     cdef Py_ssize_t nrows, ncols, rowi, coli
>     cdef np.ndarray[np.int32_t, ndim=2] out
>
>     nrows = i.shape[0]
>     ncols = i.shape[1] # num cols to take from a for each row
>     assert a.shape[0] == nrows
>     assert i.min() >= 0 and i.max() < a.shape[1]  # wraparound is off, so reject negatives too
>     out = np.empty((nrows, ncols), dtype=np.int32)
>
>     for rowi in range(nrows):
>         for coli in range(ncols):
>             out[rowi, coli] = a[rowi, i[rowi, coli]]
>
>     return out
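Because the Cython version turns off bounds checking and wraparound, a negative entry in `i` would not be caught by `assert i.max() < a.shape[1]` alone and would index outside the intended row. A plain-Python reference with the same loop structure, plus an explicit lower-bound check, is a handy way to validate any of the faster versions on small inputs. `rowtake_ref` is a hypothetical name, not part of the thread:

```python
import numpy as np

def rowtake_ref(a, i):
    """Hypothetical reference: out[r, k] = a[r, i[r, k]], same loop as the Cython version."""
    a = np.asarray(a)
    i = np.asarray(i)
    nrows, ncols = i.shape
    assert a.shape[0] == nrows
    # Check both ends of the column range; with wraparound/boundscheck off,
    # the Cython loop would not catch a negative index on its own.
    assert i.min() >= 0 and i.max() < a.shape[1]
    out = np.empty((nrows, ncols), dtype=a.dtype)
    for rowi in range(nrows):
        for coli in range(ncols):
            out[rowi, coli] = a[rowi, i[rowi, coli]]
    return out

a = np.arange(20, dtype=np.int32).reshape(5, 4)
i = np.array([[2, 1], [3, 1], [1, 1], [0, 0], [3, 1]], dtype=np.int32)
fast = a[np.arange(a.shape[0])[:, None], i]  # the broadcasting version from this thread
assert np.array_equal(rowtake_ref(a, i), fast)
```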
>
> Cheers,
>
> Martin
> _______________________________________________
> NumPy-Discussion mailing list
> NumPy-Discussion@scipy.org
> http://mail.scipy.org/mailman/listinfo/numpy-discussion
>

