[Numpy-discussion] Slicing slower than matrix multiplication?

Jasper van de Gronde th.v.d.gronde@hccnet...
Mon Dec 14 11:27:08 CST 2009


Bruce Southey wrote:
>> So far this is the fastest code I've got:
>> ------------------------------------------------------------------------
>> import numpy as np
>>
>> nmax = 100
>>
>> def minover(Xi,S):
>>       P,N = Xi.shape
>>       SXi = Xi.copy()
>>       for i in range(P):
>>           SXi[i] *= S[i]
>>       SXi2 = np.dot(SXi,SXi.T)
>>       SXiSXi2divN = np.concatenate((SXi,SXi2),axis=1)/N
>>       w = np.random.standard_normal((N))
>>       E = np.dot(SXi,w)
>>       wE = np.concatenate((w,E))
>>       for s in range(nmax*P):
>>           mu = wE[N:].argmin()
>>           wE += SXiSXi2divN[mu]
>>           # E' = dot(SXi,w')
>>           #    = dot(SXi,w + SXi[mu,:]/N)
>>           #    = dot(SXi,w) + dot(SXi,SXi[mu,:])/N
>>           #    = E + dot(SXi,SXi.T)[:,mu]/N
>>           #    = E + dot(SXi,SXi.T)[mu,:]/N   (dot(SXi,SXi.T) is symmetric)
>>       return wE[:N]
>> ------------------------------------------------------------------------
>>
>> I am particularly interested in cleaning up the initialization part, but
>> any suggestions for improving the overall performance are of course
>> appreciated.
>>
>>    
> What are Xi and S?
> I think that your SXi is just:
> SXi=Xi*S

Sort of; it's actually (Xi.T*S).T, now that I think of it. I'll see if
that is any faster, and if there is a neater way of doing it I'd love to
hear about it.
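A minimal sketch of the row scaling, assuming Xi is a P x N array and S a length-P vector of +/-1 labels (the sizes and values below are made up for illustration). Indexing S as a column, S[:, np.newaxis], lets broadcasting scale row i of Xi by S[i] without either the loop or the double transpose:

```python
import numpy as np

# Hypothetical example data: P = 4 patterns of length N = 3, labels S = +/-1.
rng = np.random.default_rng(0)
P, N = 4, 3
Xi = rng.standard_normal((P, N))
S = np.array([1.0, -1.0, -1.0, 1.0])

# Loop version from the post: scale row i by S[i].
SXi_loop = Xi.copy()
for i in range(P):
    SXi_loop[i] *= S[i]

# Transpose trick: Xi.T has shape (N, P), so S broadcasts along its last axis.
SXi_transpose = (Xi.T * S).T

# Column-vector broadcasting: S[:, np.newaxis] has shape (P, 1).
SXi_bcast = Xi * S[:, np.newaxis]

print(np.array_equal(SXi_loop, SXi_transpose), np.array_equal(SXi_loop, SXi_bcast))
# prints: True True
```

Note that a plain Xi*S broadcasts S against the last axis (length N), so it scales columns rather than rows and raises a ValueError whenever P differs from N, as it would with the shapes above.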

> But really I do not understand what you are actually trying to do. As
> previously indicated, sometimes simplifying an algorithm can make it
> computationally slower.

It was hardly simplified; this was the original function body:
     P,N = Xi.shape
     SXi = Xi.copy()
     for i in range(P):
         SXi[i] *= S[i]              # scale pattern i by its label
     w = np.random.standard_normal((N))
     for s in range(nmax*P):
         E = np.dot(SXi,w)           # O(P*N) recomputation every iteration
         mu = E.argmin()             # pattern with the smallest E
         w += SXi[mu]/N
     return w

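As you can see, it's just some basic linear algebra: E is updated
incrementally instead of recomputing dot(SXi,w) every iteration, which cuts
the cost per iteration from O(P*N) to O(N+P) and the main loop from about
O(n^3) to O(n^2) (the one-off product dot(SXi,SXi.T) still costs O(P^2*N)).
On top of that there are some less elegant tweaks to reduce the high Python
overhead, such as packing w and E into one array so that each iteration
needs only a single in-place addition.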
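A sketch, under stated assumptions, that combines the broadcasting initialization with the incremental loop and checks the invariant behind the derivation in the code comments: after any number of updates, the running E should still equal dot(SXi, w) up to rounding. Passing nmax as an argument instead of a global, the optional NumPy Generator `rng`, and returning E alongside w are assumptions made for illustration; it needs Python 3 and NumPy 1.17 or later (for np.random.default_rng).

```python
import numpy as np

def minover_incremental(Xi, S, nmax=100, rng=None):
    """Incremental minover loop, sketched from the post.

    Hypothetical interface: nmax and rng are parameters, and the
    incrementally maintained E is returned so it can be checked.
    """
    rng = np.random.default_rng() if rng is None else rng
    P, N = Xi.shape
    SXi = Xi * S[:, np.newaxis]              # row i scaled by S[i]
    # Row mu of [SXi | SXi @ SXi.T] / N is the update for [w | E].
    update = np.hstack((SXi, SXi @ SXi.T)) / N
    w = rng.standard_normal(N)
    wE = np.concatenate((w, SXi @ w))        # w in wE[:N], E in wE[N:]
    for _ in range(nmax * P):
        mu = wE[N:].argmin()                 # pattern with the smallest E
        wE += update[mu]                     # w += SXi[mu]/N, E += (SXi @ SXi.T)[mu]/N
    return wE[:N], wE[N:]

# Usage with made-up data: 20 patterns of length 50.
rng = np.random.default_rng(1)
Xi = rng.standard_normal((20, 50))
S = rng.choice([-1.0, 1.0], size=20)
w, E = minover_incremental(Xi, S, nmax=10, rng=rng)
# The running E should match a direct recomputation, up to rounding.
print(np.allclose(E, (Xi * S[:, np.newaxis]) @ w))
```

Checking the invariant, rather than comparing w against the original function, is deliberate: the original recomputes E from scratch each step, so rounding differences can change which argmin wins when two entries of E are nearly tied, and the two trajectories may then diverge even though both follow the same update rule.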

