[Numpy-discussion] argmin & min on ndarrays
Anne Archibald
peridot.faceted@gmail....
Tue Mar 4 17:00:36 CST 2008
On 04/03/2008, Pierre GM <pgmdevlist@gmail.com> wrote:
> All,
> Let a & b be two ndarrays of the same shape. I'm trying to find the elements
> of b that correspond to the minima of a along an arbitrary axis.
> The problem is trivial when axis=None or when a.ndim=2, but I'm getting
> confused with higher dimensions: I came to the following solution that looks
> rather ugly, and I'd need some ideas to simplify it
>
> >>> a = numpy.arange(24).reshape(2,3,4)
> >>> axis = -1
> >>> b = numpy.rollaxis(a,axis,0)[a.argmin(axis)][tuple([0]*(a.ndim-1))]
> >>> numpy.all(b == a.min(axis))
> True
>
> Thanks a lot in advance for any suggestions.
I couldn't find any nice way to make indexing do what you want, but
the function choose() can be persuaded to do it. Unfortunately it will
only choose along the first axis, so some transpose jiggery-pokery is
necessary:
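    import numpy as N

    def pick_argmin(a, b, axis):
        assert a.shape == b.shape
        # Build an axis permutation that moves `axis` to the front.
        t = list(range(len(b.shape)))  # list() so that del also works on Python 3
        i = t[axis]
        del t[axis]
        t = [i] + t
        a = a.transpose(t)
        b = b.transpose(t)
        # choose() picks along the first axis of b, one index per output element.
        return N.choose(N.argmin(a, axis=0), b)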
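To make the "first axis" behaviour of choose() concrete, here is a small sketch. The arrays and the names idx, picked and expected are made up for illustration. It compares choose() with the explicit loop it stands in for:

```python
import numpy as N

b = N.arange(24).reshape(3, 2, 4)      # 3 "choices", each of shape (2, 4)
idx = N.array([[0, 1, 2, 0],
               [2, 2, 1, 0]])          # one choice index per output element

picked = N.choose(idx, b)              # picked[j, k] == b[idx[j, k], j, k]

# The same selection written as an explicit loop, for comparison.
expected = N.empty_like(picked)
for j in range(idx.shape[0]):
    for k in range(idx.shape[1]):
        expected[j, k] = b[idx[j, k], j, k]

assert (picked == expected).all()
```

One caveat, as far as I know: choose() caps the number of choice arrays (32 in the NumPy 1.x releases I am aware of), so this route may fail when the axis of interest is longer than that.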
I did find a not-nice way to do what you want. The problem is that
numpy's fancy indexing is so general that it won't let you simply pick
and choose along one axis; you have to pick and choose along all axes.
So what you do is use indices() to generate arrays that index all the
*other* axes appropriately, and then use the argmin array to index the
axis you're interested in (here a and b are three-dimensional with
shape (2, n, 4) for some n, and the axis of interest is 1):
In [39]: c = N.indices((2,4))

In [40]: b[c[0],N.argmin(a,axis=1),c[1]]
Out[40]:
array([[-0.70659942, -0.997249  , -0.20028296, -0.05171191],
       [-1.28886394, -1.0610526 , -1.07193295,  0.05356948]])

In [42]: c[0]
Out[42]:
array([[0, 0, 0, 0],
       [1, 1, 1, 1]])

In [43]: c[1]
Out[43]:
array([[0, 1, 2, 3],
       [0, 1, 2, 3]])
Not only would this require similar jiggery-pokery to handle an
arbitrary axis, it also creates the potentially very large intermediate
array c. I'd stick with choose().
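For what it's worth, the indices() trick can be generalised to an arbitrary axis roughly as follows. This is only a sketch: pick_argmin_indices is a hypothetical name, not something from numpy, and it still builds the large intermediate index arrays.

```python
import numpy as N

def pick_argmin_indices(a, b, axis):
    # Hypothetical helper, shown only to illustrate the indices() approach.
    assert a.shape == b.shape
    axis = axis % a.ndim                # normalise negative axes
    am = N.argmin(a, axis=axis)         # has the shape of a with `axis` removed
    c = list(N.indices(am.shape))       # one index array per remaining axis
    c.insert(axis, am)                  # the argmin array indexes the axis of interest
    return b[tuple(c)]
```

For three-dimensional a and b with axis=1, the final line reduces to the same expression as in the session above, b[c[0], N.argmin(a,axis=1), c[1]].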
A third option would be to transpose() and reshape() a and b down to
two dimensions, then reshape() the result back to the right shape.
More multiaxis jiggery-pokery, and the reshape()s may end up copying
the arrays.
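A rough sketch of that third option, assuming the axis of interest is moved to the end before flattening. pick_argmin_2d is a made-up name, and the comments mark where a copy may happen:

```python
import numpy as N

def pick_argmin_2d(a, b, axis):
    # Hypothetical helper, shown only to illustrate the transpose/reshape approach.
    assert a.shape == b.shape
    axis = axis % a.ndim
    # Move `axis` to the end, keeping the other axes in their original order.
    t = [i for i in range(a.ndim) if i != axis] + [axis]
    a2 = a.transpose(t)
    b2 = b.transpose(t)
    out_shape = a2.shape[:-1]
    n = a.shape[axis]
    a2 = a2.reshape(-1, n)              # may copy if the transposed data is not contiguous
    b2 = b2.reshape(-1, n)              # likewise
    rows = N.arange(a2.shape[0])
    # The two-dimensional case: one argmin per row, then pick from b row by row.
    return b2[rows, a2.argmin(axis=1)].reshape(out_shape)
```

As a sanity check, pick_argmin_2d(a, a, axis) should agree with a.min(axis).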
Finally, you can always just write a Python loop (over all axes except
the one of interest) using ndenumerate() and one-dimensional argmin().
If the dimension you're argmin()ing over is very large, the cost of
the Python loop may be negligible.
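The loop version might look something like the sketch below. Note that it uses ndindex() rather than ndenumerate(), since only the index tuples are needed here; that substitution, and the name pick_argmin_loop, are mine.

```python
import numpy as N

def pick_argmin_loop(a, b, axis):
    # Hypothetical helper, shown only to illustrate the explicit-loop approach.
    assert a.shape == b.shape
    axis = axis % a.ndim
    out_shape = a.shape[:axis] + a.shape[axis + 1:]
    out = N.empty(out_shape, dtype=b.dtype)
    # Loop over every position in all axes except the one of interest.
    for idx in N.ndindex(*out_shape):
        # Splice a full slice in at `axis` to get a one-dimensional view.
        full = idx[:axis] + (slice(None),) + idx[axis:]
        out[idx] = b[full][N.argmin(a[full])]
    return out
```

Each iteration does a single one-dimensional argmin(), so the Python overhead is paid once per output element rather than once per input element.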
Anne
P.S. feel free to use pick_argmin however you like, though error
handling would probably be a good idea... -A