[Numpy-discussion] extract elements of an array that are contained in another array?
Alan G Isaac
aisaac@american....
Thu Jun 4 09:13:19 CDT 2009
> On Thu, Jun 4, 2009 at 8:23 AM, Alan G Isaac <aisaac@american.edu> wrote:
>> a[(a==b[:,None]).sum(axis=0,dtype=bool)]
On 6/4/2009 8:35 AM josef.pktd@gmail.com apparently wrote:
> If b is large this creates a huge intermediate array
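To make the memory concern concrete, here is a minimal sketch of the quoted broadcasting approach with made-up sample arrays. It uses `.any(axis=0)`, which gives the same mask as the quoted boolean `sum`. The comparison `a == b[:, None]` materializes a boolean array of shape `(len(b), len(a))` before it is reduced:

```python
import numpy as np

a = np.array([5, 1, 3, 1, 7])
b = np.array([1, 7, 9])

# Broadcasting compares every element of b with every element of a,
# producing a (len(b), len(a)) boolean intermediate.
eq = a == b[:, None]
print(eq.shape)  # (3, 5)

# Reduce over the b axis: True where a[i] equals some element of b.
mask = eq.any(axis=0)
out = a[mask]
print(out)  # [1 1 7]
```

With `len(a) == len(b) == 10**5`, that intermediate would hold 10**10 booleans, roughly 10 GB at one byte each.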
True enough, but one could then use fromiter:
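import numpy as np
setb = set(b)  # build the lookup set once; membership tests are O(1) on average
itr = (ai for ai in a if ai in setb)  # lazily keep a's elements that are in b
out = np.fromiter(itr, dtype=a.dtype)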
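A self-contained version of the `fromiter` approach, wrapped in a function. The name `extract_members` and the sample data are hypothetical. It keeps the order and the duplicates of `a`, and its extra memory is the set built from `b` plus the output, rather than a `len(a) * len(b)` intermediate:

```python
import numpy as np

def extract_members(a, b):
    """Return the elements of a that also occur in b, in a's order.

    Hypothetical helper: builds a set from b once, then streams the
    matching elements of a into a new array via np.fromiter.
    """
    a = np.asarray(a)
    # .tolist() yields plain Python scalars, which are cheap to hash.
    setb = set(np.asarray(b).tolist())
    itr = (ai for ai in a.tolist() if ai in setb)
    return np.fromiter(itr, dtype=a.dtype)

a = np.array([5, 1, 3, 1, 7])
b = np.array([1, 7, 9])
print(extract_members(a, b))  # [1 1 7]
```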
I suspect (?) that b would have to be pretty
big relative to a for the repeated testing
to be more costly than sorting a.
Or if a stable order is not important (I don't
recall if the OP specified), one could just
np.intersect1d(a, np.unique(b))
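In current NumPy releases, `np.intersect1d` returns the sorted, unique values common to both inputs, so this call drops both the order and any repeats in `a`. If either matters, a membership mask works. The sketch below uses `np.isin` (available since NumPy 1.13), a swapped-in alternative rather than anything proposed in this thread, with made-up sample arrays:

```python
import numpy as np

a = np.array([5, 1, 3, 1, 7])
b = np.array([1, 7, 9])

# Set-style result: sorted unique values present in both arrays.
common = np.intersect1d(a, np.unique(b))
print(common)  # [1 7]

# Order-preserving result that keeps a's duplicates.
out = a[np.isin(a, b)]
print(out)  # [1 1 7]
```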
On a different note, I think a name change
is needed for your function. (Compare
intersect1d_nu to see the potential
confusion. And btw, what is the use case
for intersect1d, which gives neither a
set intersection nor a multiset intersection?)
Cheers,
Alan Isaac
