[NumPy-Tickets] [NumPy] #2163: BLAS matrix product (dot) never used for ndim > 2 (tensordot does not use BLAS)
NumPy Trac
numpy-tickets@scipy....
Fri Jun 15 11:09:52 CDT 2012
#2163: BLAS matrix product (dot) never used for ndim > 2 (tensordot does not use
BLAS)
------------------------------------------------+---------------------------
Reporter:  thatistosay                          |  Owner:      somebody
Type:      enhancement                          |  Status:     new
Priority:  normal                               |  Milestone:  Unscheduled
Component: numpy.core                           |  Version:    1.6.1
Keywords:  dot, blas, tensordot, optimization   |
------------------------------------------------+---------------------------
dotblas_matrixproduct() contains the comment "This function doesn't handle
dimensions greater than 2" and calls PyArray_MatrixProduct2() for these
cases. This means BLAS is never used for calls to dot() with arguments of
ndim > 2. In more detail:
If I want to contract a pair of tensor indices for ndim=3 such that
{{{
A = np.random.rand(d,D,D); B = np.random.rand(d,D,D)
AB[s,i,t,j] == sum(A[s,i,:] * B[t,:,j])
}}}
then currently, although it can be done in a single line
{{{
res = np.dot(A,B)
}}}
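As a quick check of the index convention above, the sketch below compares dot() on two 3D arrays against an explicit einsum() contraction. The sizes d and D and the fixed random seed are placeholder values chosen for illustration only; they are not taken from the ticket.

```python
import numpy as np

d, D = 3, 4  # placeholder sizes, chosen only for illustration
rng = np.random.RandomState(0)
A = rng.rand(d, D, D)
B = rng.rand(d, D, D)

# dot() contracts the last axis of A with the second-to-last axis of B,
# giving AB[s,i,t,j] = sum_k A[s,i,k] * B[t,k,j]
AB = np.dot(A, B)

# The same contraction written with explicit index labels
AB_ref = np.einsum('sik,tkj->sitj', A, B)

print(AB.shape)
print(np.allclose(AB, AB_ref))
```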
it can often be done much faster with explicit (Python!) loops
{{{
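res = np.zeros((d,d,D,D))                 # res[s,t] will hold the 2D product of A[s] and B[t]
for s in range(d):
    for t in range(d):
        np.dot(A[s], B[t], out=res[s,t])  # 2D dot(), so BLAS can be used
res = np.rollaxis(res, 2, 1)              # reorder axes (s,t,i,j) -> (s,i,t,j)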
}}}
This assumes that dot() is using an optimized BLAS for ndim=2 and that the
dimensions are large enough for calling BLAS to be worth it.
In general, reproducing the behaviour of dot() for ndim>2 is just a matter
of calling GEMM in loops as above and then calling rollaxis() once.
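The general case can be sketched in Python as follows. This is only an illustration of the loop-plus-rollaxis() idea for arbitrary ndim > 2, not a proposed implementation; the function name dot_nd_via_2d is hypothetical, and the Python-level loops are exactly what a C version would be meant to eliminate.

```python
import numpy as np

def dot_nd_via_2d(a, b):
    """Reproduce np.dot(a, b) for arrays with more than two axes,
    using only 2D dot() calls followed by a single rollaxis()."""
    a = np.asarray(a)
    b = np.asarray(b)
    lead_a = a.shape[:-2]            # leading axes of a
    lead_b = b.shape[:-2]            # leading axes of b
    rows, cols = a.shape[-2], b.shape[-1]
    res = np.empty(lead_a + lead_b + (rows, cols),
                   dtype=np.result_type(a, b))
    # One 2D matrix product per combination of leading indices
    for ia in np.ndindex(*lead_a):
        for ib in np.ndindex(*lead_b):
            res[ia + ib] = np.dot(a[ia], b[ib])
    # res has axes (lead_a, lead_b, rows, cols); dot() returns
    # (lead_a, rows, lead_b, cols), so move the rows axis forward.
    n, m = len(lead_a), len(lead_b)
    return np.rollaxis(res, n + m, n)

rng = np.random.RandomState(1)
a = rng.rand(2, 3, 4, 5)
b = rng.rand(6, 5, 7)
print(np.allclose(dot_nd_via_2d(a, b), np.dot(a, b)))
```

For two 3D arguments this reduces to the rollaxis(res, 2, 1) call used in the explicit loop version.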
I therefore propose doing this within _dotblas.c as far as possible (to
eliminate the use of python loops) so that ndim>2 dot(), and tensordot(),
can benefit from BLAS.
Some comparisons of the two methods above (attached script), giving the run
time of the looped version as a percentage of the run time of 3D dot():
{{{
dtype=complex128
AB[s,i,t,j] = sum(A[s,i,:] * B[t,:,j])
A.shape = (16, 512, 512); B.shape = (16, 512, 512)
looping over 2D dot() vs. 3D dot(): 24% (about 4 times faster)
A.shape = (20, 64, 64); B.shape = (20, 64, 64)
looping over 2D dot() vs. 3D dot(): 35%
A.shape = (20, 48, 32); B.shape = (20, 32, 48)
looping over 2D dot() vs. 3D dot(): 45%
A.shape = (32, 32, 16); B.shape = (32, 16, 32)
looping over 2D dot() vs. 3D dot(): 82%
A.shape = (64, 10, 8); B.shape = (64, 8, 10)
looping over 2D dot() vs. 3D dot(): 158% (slow Python loops)
}}}
(This was on a 4-core i7 system using ATLAS, under heavy load.)
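A comparison of this kind could be scripted roughly as below. This is a sketch only and not the attached script: the helper names looped_dot and time_ratio, the random seed, and the repeat count are all assumptions made for illustration. Timings depend heavily on the NumPy version, the BLAS library, and the machine, so no expected numbers are given.

```python
import timeit
import numpy as np

def looped_dot(A, B):
    """3D dot() built from 2D dot() calls plus one rollaxis()."""
    res = np.empty((A.shape[0], B.shape[0], A.shape[1], B.shape[2]),
                   dtype=np.result_type(A, B))
    for s in range(A.shape[0]):
        for t in range(B.shape[0]):
            np.dot(A[s], B[t], out=res[s, t])
    return np.rollaxis(res, 2, 1)

def time_ratio(shape_a, shape_b, repeats=3):
    """Return (looped 2D dot() time) / (3D dot() time) for random
    complex128 inputs, using the best of `repeats` runs for each."""
    rng = np.random.RandomState(0)
    A = rng.rand(*shape_a) + 1j * rng.rand(*shape_a)
    B = rng.rand(*shape_b) + 1j * rng.rand(*shape_b)
    t_loop = min(timeit.repeat(lambda: looped_dot(A, B),
                               number=1, repeat=repeats))
    t_3d = min(timeit.repeat(lambda: np.dot(A, B),
                             number=1, repeat=repeats))
    return t_loop / t_3d

if __name__ == "__main__":
    ratio = time_ratio((20, 48, 32), (20, 32, 48))
    print("looping over 2D dot() vs. 3D dot(): %.0f%%" % (100 * ratio))
```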
--
Ticket URL: <http://projects.scipy.org/numpy/ticket/2163>
NumPy <http://projects.scipy.org/numpy>