[Numpy-discussion] by axis iterator

Gabriel Gellner ggellner@uoguelph...
Wed Nov 12 12:47:33 CST 2008

Something I use a lot is a little generator that iterates over an ndarray along a
given axis. I was wondering whether this is already built into numpy (other than
apply_along_axis, which I find ugly), and if not, would there be interest in
adding it?

The function is just:

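def by_axis(ndobj, axis=0):
    # one full slice per dimension, i.e. ndobj[:, :, ...]
    index_set = [slice(None)] * ndobj.ndim
    for i in range(ndobj.shape[axis]):
        # fix the chosen axis to position i
        index_set[axis] = i
        # index with a tuple; recent NumPy no longer accepts a list of slices here
        yield ndobj[tuple(index_set)]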

and can be used like this:

>>> [sum(x) for x in by_axis(a, 1)]
>>> for col in by_axis(a, 1):
...     print(col)

I use it when porting R code that uses a lot of apply-like logic. I know most
numpy functions have the axis argument built in, but when writing my own
functions I find this a real time saver.
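Since the post mentions porting R's apply, here is a minimal sketch of how the generator can stand in for R's `apply(X, MARGIN, FUN)`. The helper `r_apply` and its 1-based `margin` argument are hypothetical, made up for this illustration, and it assumes `func` returns a scalar for every slice. The sketch also contrasts the axis convention with `np.apply_along_axis`, which names the axis it runs *along* rather than the axis being stepped through.

```python
import numpy as np


def by_axis(ndobj, axis=0):
    """Yield ndobj with `axis` fixed to 0, 1, 2, ... (the generator from the post)."""
    index_set = [slice(None)] * ndobj.ndim
    for i in range(ndobj.shape[axis]):
        index_set[axis] = i
        yield ndobj[tuple(index_set)]


def r_apply(x, margin, func):
    """Hypothetical helper loosely mimicking R's apply(X, MARGIN, FUN).

    `margin` is 1-based, as in R, and only a single margin is supported.
    `func` is assumed to return a scalar for each slice.
    """
    return np.array([func(s) for s in by_axis(x, margin - 1)])


m = np.arange(6).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
col_sums = r_apply(m, 2, np.sum)      # like R's apply(m, 2, sum): one value per column
print(col_sums)                       # [3 5 7]

# apply_along_axis names the axis it runs *along*: axis=0 on a 2-D array
# also visits the columns, i.e. the same slices as by_axis(m, 1).
same = np.apply_along_axis(np.sum, 0, m)

# apply_along_axis only ever passes 1-D slices, whereas by_axis hands func
# whole sub-arrays, so a margin on a 3-D array works as it does in R.
x = np.arange(24).reshape(2, 3, 4)
slab_sums = r_apply(x, 3, np.sum)     # sums of x[:, :, k] for k = 0..3
print(slab_sums)                      # [60 66 72 78]
```

Because `by_axis` simply yields whatever sub-array remains after fixing one index, the function passed in sees exactly the slice R would pass for that margin, which is what makes it convenient for ports.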

Anyway, if someone can show me a better way I would be overjoyed; otherwise, if
people like this, I can make a ticket on Trac.
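One possible answer to the "better way" question, offered as a sketch rather than an existing built-in: iterating over an ndarray already steps along its first axis, so moving the requested axis to the front yields the same slices. This assumes NumPy 1.11 or later for `np.moveaxis`; on older versions `np.rollaxis(ndobj, axis)` does the same job here. The 2×3 array `a` is only an illustration.

```python
import numpy as np


def by_axis(ndobj, axis=0):
    """Iterate over ndobj along `axis` by moving that axis to the front.

    Iterating over an ndarray walks its first axis, so after np.moveaxis the
    loop yields the same sub-array views as the index-list generator.
    """
    # np.moveaxis returns a view, so no data is copied
    return iter(np.moveaxis(ndobj, axis, 0))


a = np.arange(6).reshape(2, 3)
for col in by_axis(a, 1):
    print(col)        # [0 3], then [1 4], then [2 5]
```

The design choice here is to lean on NumPy's own iteration protocol instead of building index tuples by hand; negative axes also work, since `np.moveaxis` accepts them.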

