[Numpy-discussion] performance of scipy: potential inefficiency in logsumexp and sampling from multinomial
per freem
perfreem@gmail....
Mon Oct 12 16:25:39 CDT 2009
hi all,
i have a piece of code that relies heavily on sampling from
multinomial distributions and using their results to compute log
probabilities. my code makes heavy use of 'multinomial' (from
numpy.random) and of 'logsumexp' (from scipy.maxentropy).
my code is unusually slow, and profiling it with Python's "cProfile"
module shows that most of the time is spent in the following
functions:
    cumtime  percall  function
    479.524    0.000  code.py:211(my_func)
    122.682    0.000  /Library/Python/2.5/site-packages/scipy/maxentropy/maxentutils.py:27(logsumexp)
     40.645    0.000  /Library/Python/2.5/site-packages/numpy/core/numeric.py:180(asarray)
     20.374    0.000  {method 'max' of 'numpy.ndarray' objects}

(the first column is cumulative time and the second is per-call time, both in seconds.)
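For anyone reproducing this kind of listing, here is a minimal sketch using the standard-library cProfile and pstats modules. profile_top and busy are hypothetical names used only for illustration, and the exact column layout of the report may differ between Python versions.

```python
import cProfile
import io
import pstats


def profile_top(func, *args, n=10):
    """Run func(*args) under cProfile and return the top-n report sorted by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    func(*args)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # "cumulative" sorts by cumtime, the first column in the listing above
    stats.sort_stats("cumulative").print_stats(n)
    return stream.getvalue()


def busy(k):
    # hypothetical stand-in workload, not the code from this post
    return sum(i * i for i in range(k))


if __name__ == "__main__":
    print(profile_top(busy, 100000))
```

From the command line, `python -m cProfile -s cumulative code.py` prints a similar report for a whole script.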
my code (listed as 'my_func' above) essentially computes a list of log
probabilities, renormalizes them in log space (using 'logsumexp'),
exponentiates them, and then samples from a multinomial distribution
using those probabilities as its parameter. i then check which outcome
was drawn in the multinomial sample. here's a sketch of the code:
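from numpy import array, exp
from numpy.random import multinomial
from scipy.special import logsumexp  # scipy.maxentropy.logsumexp in older SciPy

def my_func(my_list, n_items):
    # my_dict (defined elsewhere) maps (item, position) keys to log probabilities
    final_list = []
    for n in range(n_items):
        prob = my_dict[(my_list[n], n)]
        final_list.append(prob)
    # renormalize in log space: subtract log(sum(exp(final_list)))
    final_list = array(final_list) - logsumexp(final_list)
    # one trial, so the sample is a 0/1 vector containing a single 1
    sample = multinomial(1, exp(final_list))
    sample_index = list(sample).index(1)
    return sample_index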
the list 'my_list' usually has around 3 to 5 elements in it, and
'my_dict' has about 500-1000 keys.
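The renormalization step in the sketch can be written out as follows. The symbols ℓ, N and m are notation introduced here. The max-shifted form is the standard numerically stable way to evaluate logsumexp; that SciPy's version does this, which would account for the ndarray 'max' calls in the profile, is an assumption.

```latex
\ell_n = \texttt{my\_dict}[(\texttt{my\_list}[n],\, n)], \qquad n = 0, \dots, N-1, \quad N = \texttt{n\_items}

\operatorname{logsumexp}(\ell) = \log \sum_{k=0}^{N-1} e^{\ell_k}
  = m + \log \sum_{k=0}^{N-1} e^{\ell_k - m}, \qquad m = \max_k \ell_k

p_n = \exp\bigl(\ell_n - \operatorname{logsumexp}(\ell)\bigr), \qquad \sum_{n=0}^{N-1} p_n = 1
```

Drawing multinomial(1, p) then returns a vector whose single 1 sits at index n with probability p_n.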
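this function gets called about 1.5 million times in my code, and the
calls take about 5 minutes in total (roughly 200 microseconds per call),
which seems very long for these operations. (i'd like to scale this up
to a case where the function is called about 10-120 million times,
which at the current rate would take from about half an hour to nearly
7 hours.)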
are there known efficiency issues with logsumexp? it seems like it
should be a very cheap operation. also, 'multinomial' ought to be
relatively cheap, i believe. does anyone have any ideas on how this
can be optimized? any input will be greatly appreciated. i am also
open to using cython if that is likely to make a significant
improvement in this case.
also, what is likely to be the origin of the call to "asarray"? (i am
not explicitly calling that function, it must be indirectly via some
other function.)
thanks very much.