[Numpy-discussion] Fastest distance matrix calc

Bill Baxter wbaxter@gmail....
Mon Apr 16 22:54:13 CDT 2007


Here's a bunch of distance matrix implementations and their timings.
The upshot is that for most purposes this seems to be the best, or at
least not too far off (basically the cookbook solution Keir posted):

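import numpy as npy

def dist2hd(x,y):
    """Accumulate the solution one coordinate at a time"""
    d = npy.zeros((x.shape[0],y.shape[0]),dtype=x.dtype)
    for i in range(x.shape[1]):
        # (n,1) - (m,) broadcasts to an (n,m) array of differences
        diff2 = x[:,i,None] - y[:,i]
        diff2 **= 2
        d += diff2
    # in place sqrt (needs a float dtype)
    npy.sqrt(d,d)
    return d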
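For comparison, the fully broadcast one-liner computes the same matrix but materializes an (n, m, k) temporary before summing, whereas dist2hd only ever holds (n, m) arrays. A minimal sketch checking that the two agree, written for Python 3 (`range` instead of `xrange`); `dist2_broadcast` is a hypothetical name for the one-liner:

```python
import numpy as npy

def dist2hd(x, y):
    """Accumulate squared differences one coordinate at a time."""
    d = npy.zeros((x.shape[0], y.shape[0]), dtype=x.dtype)
    for i in range(x.shape[1]):
        diff2 = x[:, i, None] - y[:, i]   # (n, 1) - (m,) -> (n, m)
        diff2 **= 2
        d += diff2
    npy.sqrt(d, d)                        # in place sqrt
    return d

def dist2_broadcast(x, y):
    """Hypothetical one-liner: builds an (n, m, k) temporary, then sums over k."""
    return npy.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))

rng = npy.random.RandomState(0)
x = rng.rand(10, 3)
y = rng.rand(7, 3)
print(npy.allclose(dist2hd(x, y), dist2_broadcast(x, y)))  # True
```

The memory difference grows with k: for 10 points in 100 dimensions against 10 others, the one-liner allocates a 10x10x100 temporary, while dist2hd reuses 10x10 arrays.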

The only place where it's far from the best is for a small number of
points (~10) with high dimensionality (~100), which does come up in
machine learning contexts.  For those cases this one does much better
(by a factor of …):

def dist2b3(x,y):
    d = npy.dot(x,y.T)
    d *= -2.0
    d += (x*x).sum(1)[:,None]
    d += (y*y).sum(1)
    # Rounding errors occasionally cause negative entries in d
    d[d<0] = 0
    # in place sqrt
    npy.sqrt(d,d)
    return d
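dist2b3 relies on expanding the squared distance, which turns the whole computation into one matrix product plus two vectors of squared row norms:

```latex
d_{ij}^2 = \lVert x_i - y_j \rVert^2
         = \lVert x_i \rVert^2 - 2\, x_i^{\top} y_j + \lVert y_j \rVert^2
```

Here npy.dot(x,y.T) supplies the x_i^T y_j terms, (x*x).sum(1)[:,None] the column of ||x_i||^2 and (y*y).sum(1) the row of ||y_j||^2. When x_i and y_j are nearly equal, the right-hand side is a small difference of large numbers, so cancellation can push it slightly below zero; that is why negative entries are clipped before the sqrt.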

So given that, the obvious solution (if you don't want to delve into
non-numpy solutions) is to use a hybrid that just switches between
the two.  I'm not sure what the proper switch point is, since it seems
kind of complicated and probably depends somewhat on cache specifics.
But just switching based on the dimension of the points (loop for
fewer than 5 coordinates, dot product otherwise) seems to be pretty
effective:

def dist2hy(x,y):
    if x.shape[1]<5:
        d = npy.zeros((x.shape[0],y.shape[0]),dtype=x.dtype)
        for i in range(x.shape[1]):
            diff2 = x[:,i,None] - y[:,i]
            diff2 **= 2
            d += diff2
        npy.sqrt(d,d)
        return d

    else:
        d = npy.dot(x,y.T)
        d *= -2.0
        d += (x*x).sum(1)[:,None]
        d += (y*y).sum(1)
        # Rounding errors occasionally cause negative entries in d
        d[d<0] = 0
        # in place sqrt
        npy.sqrt(d,d)
        return d
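The threshold of 5 is a heuristic, and since the best switch point likely depends on cache details, one way to pick it on a given machine is to time both versions over a range of dimensions and take the first dimension where the dot-product version wins. A rough sketch using `timeit`; the helper `find_crossover`, the candidate dimensions and the point counts are illustrative assumptions, and results will vary with the BLAS and cache sizes:

```python
import timeit
import numpy as npy

def dist2hd(x, y):
    # coordinate-at-a-time loop (low-dimensional branch)
    d = npy.zeros((x.shape[0], y.shape[0]), dtype=x.dtype)
    for i in range(x.shape[1]):
        diff2 = x[:, i, None] - y[:, i]
        diff2 **= 2
        d += diff2
    npy.sqrt(d, d)
    return d

def dist2b3(x, y):
    # dot-product expansion (high-dimensional branch)
    d = npy.dot(x, y.T)
    d *= -2.0
    d += (x * x).sum(1)[:, None]
    d += (y * y).sum(1)
    d[d < 0] = 0          # clip rounding-error negatives
    npy.sqrt(d, d)
    return d

def find_crossover(n_points=200, dims=(1, 2, 3, 5, 8, 16, 32, 64), repeat=5):
    """Hypothetical helper: smallest dimension in `dims` where dist2b3 is
    faster than dist2hd on random data, or None if it never is."""
    rng = npy.random.RandomState(0)
    for k in dims:
        x = rng.rand(n_points, k)
        y = rng.rand(n_points, k)
        # best-of-`repeat` timings, 3 calls each
        t_loop = min(timeit.repeat(lambda: dist2hd(x, y), number=3, repeat=repeat))
        t_dot = min(timeit.repeat(lambda: dist2b3(x, y), number=3, repeat=repeat))
        if t_dot < t_loop:
            return k
    return None

print(find_crossover())
```

Taking the minimum over repeats reduces noise from other processes; the point counts matter too, so it is worth timing with sizes close to the real workload.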

All of this assumes 'C' contiguous data.  All bets are off if you have
non-contiguous or 'F' ordered data, and maybe also if x and y have very
different numbers of points.
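If inputs might be Fortran-ordered or strided views, one cheap option is to normalize them before calling any of the functions above. A sketch assuming a hypothetical wrapper `as_c_float` and float64 data; `npy.ascontiguousarray` copies only when the input is not already C-contiguous with the requested dtype. Casting to float also sidesteps a separate pitfall: dist2hd allocates d with x.dtype, so integer input would make the in-place `npy.sqrt(d,d)` fail.

```python
import numpy as npy

def as_c_float(a):
    """Hypothetical helper: return `a` as a C-contiguous float64 array."""
    return npy.ascontiguousarray(a, dtype=npy.float64)

x = npy.asfortranarray(npy.arange(12.0).reshape(4, 3))  # 'F' ordered
y = npy.arange(30.0).reshape(10, 3)[::2]                 # strided, non-contiguous view
print(x.flags['C_CONTIGUOUS'], y.flags['C_CONTIGUOUS'])    # False False

xc, yc = as_c_float(x), as_c_float(y)
print(xc.flags['C_CONTIGUOUS'], yc.flags['C_CONTIGUOUS'])  # True True
```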


--bb



On 4/17/07, Keir Mierle <mierle@gmail.com> wrote:
> On 4/13/07, Timothy Hochberg <tim.hochberg@ieee.org> wrote:
> > On 4/13/07, Bill Baxter <wbaxter@gmail.com> wrote:
> > > I think someone posted some timings about this before but I don't recall.
>
> http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/498246
>
> [snip]
> > I'm going to go out on a limb and contend, without running any timings, that
> > for large M and N, a solution using a for loop will beat either of those.
> > For example (untested):
> >
> > results = empty([M, N], float)
> > # You could be fancy and swap axes depending on which array is larger, but
> > # I'll leave that for someone else
> > for i, v in enumerate(x):
> >     results[i] = sqrt(sum((v-y)**2, axis=-1))
> > Or something like that. The reason that I suspect this will be faster is
> > that it has better locality, completely finishing a computation on a
> > relatively small working set before moving onto the next one. The one-liners
> > have to pull the potentially large MxN array into the processor repeatedly.
>
> In my experience, it is indeed the case that the for loop version is
> faster. The fastest of the three versions offered in the above url is
> the last:
>
> from numpy import array, sqrt, zeros, newaxis
> def calcDistanceMatrixFastEuclidean2(nDimPoints):
>     nDimPoints = array(nDimPoints)
>     n,m = nDimPoints.shape
>     delta = zeros((n,n),'d')
>     for d in range(m):
>         data = nDimPoints[:,d]
>         delta += (data - data[:,newaxis])**2
>     return sqrt(delta)
>
> This is easily extended to two different nDimPoints matrices.
>
> Cheers,
> Keir
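Tim's untested row-at-a-time loop quoted above, filled out as a runnable Python 3 sketch. The explicit numpy imports, the name `dist_rows`, and taking M and N from the input shapes are assumptions; the original leaves them implicit:

```python
import numpy as npy

def dist_rows(x, y):
    """Hypothetical helper: fill the distance matrix one row at a time.

    Each iteration works on y, one (N, k) temporary and a single
    length-N output row, which is the locality argument in the quote."""
    M, N = x.shape[0], y.shape[0]
    results = npy.empty((M, N), dtype=float)
    for i, v in enumerate(x):
        # v - y broadcasts (k,) against (N, k) -> (N, k)
        results[i] = npy.sqrt(npy.sum((v - y) ** 2, axis=-1))
    return results

x = npy.array([[0.0, 0.0], [1.0, 1.0]])
y = npy.array([[3.0, 4.0], [1.0, 1.0], [0.0, 0.0]])
print(dist_rows(x, y))
```

Looping over the larger of x and y reduces Python overhead per element but enlarges each temporary; as Tim notes, swapping axes based on size is left open.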
-------------- next part --------------
A non-text attachment was scrubbed...
Name: dist2perf.out
Type: application/octet-stream
Size: 6214 bytes
Desc: not available
Url : http://projects.scipy.org/pipermail/numpy-discussion/attachments/20070417/a7d1446d/attachment.obj 

