[Numpy-discussion] Distance Matrix speed
Alan G Isaac
aisaac at american.edu
Sun Jun 18 23:30:12 CDT 2006
On Sun, 18 Jun 2006, Tim Hochberg apparently wrote:
> Alan G Isaac wrote:
>> On Sun, 18 Jun 2006, Sebastian Beca apparently wrote:
>>> def dist():
>>>     d = zeros([N, C], dtype=float)
>>>     if N < C:
>>>         for i in range(N):
>>>             xy = A[i] - B
>>>             d[i,:] = sqrt(sum(xy**2, axis=1))
>>>         return d
>>>     else:
>>>         for j in range(C):
>>>             xy = A - B[j]
>>>             d[:,j] = sqrt(sum(xy**2, axis=1))
>>>         return d
>> But that is 50% slower than Johannes's version:
>> def dist_loehner1():
>>     d = A[:, newaxis, :] - B[newaxis, :, :]
>>     d = sqrt((d**2).sum(axis=2))
>>     return d
> Are you sure about that? I just ran it through timeit, using Sebastian's
> array sizes and I get Sebastian's version being 150% faster. This
> could well be cache size dependent, so may vary from box to box, but I'd
> expect Sebastian's current version to scale better in general.
No, I'm not sure.
The script is attached at the bottom.
The most recent output follows; for reasons I have not determined,
it doesn't match my previous runs ...
Alan
>>> execfile(r'c:\temp\temp.py')
dist_beca : 3.042277
dist_loehner1: 3.170026
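
Single timeit totals like the two above can shift from run to run with
whatever else the machine is doing. A minimal sketch of one common way to
reduce that noise, assuming a current Python 3 and NumPy rather than the 2006
setup used in this thread: repeat each measurement several times and keep the
minimum. The names dist_loop, dist_broadcast and best_time are illustrative
only and do not come from the original script.

```python
import timeit

import numpy as np

K, C, N = 10, 2500, 3
rng = np.random.default_rng(0)
A = rng.random((N, K))
B = rng.random((C, K))


def dist_loop():
    # Loop over the rows of A (the N < C branch of the looped version).
    d = np.zeros((N, C))
    for i in range(N):
        xy = A[i] - B
        d[i, :] = np.sqrt(np.sum(xy**2, axis=1))
    return d


def dist_broadcast():
    # Fully broadcast version: builds an (N, C, K) intermediate array.
    d = A[:, np.newaxis, :] - B[np.newaxis, :, :]
    return np.sqrt((d**2).sum(axis=2))


def best_time(func, number=100, repeat=5):
    # Each entry returned by timeit.repeat is the total time for `number`
    # calls; the minimum is the run least disturbed by other activity.
    return min(timeit.repeat(func, number=number, repeat=repeat))


times = {f.__name__: best_time(f) for f in (dist_loop, dist_broadcast)}
for name, t in times.items():
    print("%-15s: %10.6f" % (name, t))
```

The minimum is not a typical-case figure, but it tends to be more stable
between sessions than a single total, which makes the two versions easier to
compare.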
#################################
#THE SCRIPT
import sys
sys.path.append("c:\\temp")
import numpy
from numpy import *
import timeit

K = 10
C = 2500
N = 3 # One could switch around C and N now.
A = numpy.random.random( [N, K] )
B = numpy.random.random( [C, K] )

# beca
def dist_beca():
    d = zeros([N, C], dtype=float)
    if N < C:
        for i in range(N):
            xy = A[i] - B
            d[i,:] = sqrt(sum(xy**2, axis=1))
        return d
    else:
        for j in range(C):
            xy = A - B[j]
            d[:,j] = sqrt(sum(xy**2, axis=1))
        return d

#loehnert
def dist_loehner1():
    # drawback: memory usage temporarily doubled
    # solution see below
    d = A[:, newaxis, :] - B[newaxis, :, :]
    # written as 3 expressions for more clarity
    d = sqrt((d**2).sum(axis=2))
    return d

if __name__ == "__main__":
    t1 = timeit.Timer('dist_beca()', 'from temp import dist_beca').timeit(100)
    t8 = timeit.Timer('dist_loehner1()', 'from temp import dist_loehner1').timeit(100)
    fmt="%-10s:\t"+"%10.6f"
    print fmt%('dist_beca', t1)
    print fmt%('dist_loehner1', t8)
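
Before comparing speeds it is worth confirming that the two approaches return
the same matrix, and seeing how large the broadcast intermediate is. The
sketch below assumes a current Python 3 and NumPy and reuses the array sizes
from the script; the variable names diff, d_broadcast and d_loop are
illustrative and are not part of the original code.

```python
import numpy as np

K, C, N = 10, 2500, 3
rng = np.random.default_rng(0)
A = rng.random((N, K))
B = rng.random((C, K))

# Broadcast version: the difference array holds one K-vector for every
# (row of A, row of B) pair, so its shape is (N, C, K).
diff = A[:, np.newaxis, :] - B[np.newaxis, :, :]
d_broadcast = np.sqrt((diff**2).sum(axis=2))

# Looped version: each pass only needs a (C, K) temporary.
d_loop = np.empty((N, C))
for i in range(N):
    d_loop[i, :] = np.sqrt(((A[i] - B)**2).sum(axis=1))

print("intermediate:", diff.shape, diff.nbytes, "bytes")
print("result      :", d_broadcast.shape, d_broadcast.nbytes, "bytes")
print("same result :", np.allclose(d_loop, d_broadcast))
```

With these sizes the intermediate holds K times as many elements as the
result, which is one plausible reason the relative timings could depend on
cache size, as suggested in the quoted reply.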