[SciPy-dev] Enhancement proposal for generic.rvs method in scipy.stats.distributions
Per.Brodtkorb@f...
Thu Aug 14 06:12:00 CDT 2008
I would propose that the rvs method compute the common shape of the
inputs (i.e., the shape, location and scale parameters and the size
information provided) according to numpy broadcasting rules.
I find this feature practical, and it is similar to what MATLAB does.
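A stand-alone sketch of the proposed behaviour, restricted to a normal distribution. `normal_rvs` is a hypothetical name, not a scipy function. It uses NumPy's `default_rng` and `np.broadcast_shapes` (NumPy 1.20 or later) in place of the scipy internals `_argcheck`, `_rvs` and `__fix_loc_scale`. Like the rvs redefinition in the PS, it represents `size` by `zeros(size)`, so the default `size=1` contributes shape (1,) to the common shape.

```python
import numpy as np


def normal_rvs(loc=0.0, scale=1.0, size=1, rng=None):
    """Hypothetical analogue of the proposed rvs for a normal distribution:
    size, loc and scale are broadcast to one common output shape."""
    rng = np.random.default_rng() if rng is None else rng
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    # Same domain check as the proposal: scale must be non-negative.
    if not np.all(scale >= 0):
        raise ValueError("Domain error in arguments.")
    # Represent `size` by an array of that shape and broadcast it together
    # with the parameters; non-conforming shapes raise ValueError here.
    cshape = np.broadcast_shapes(np.zeros(size).shape, loc.shape, scale.shape)
    # Draw standard normal variates of the common shape, then scale and shift.
    vals = rng.standard_normal(cshape)
    return vals * scale + loc


# Usage: a (2, 1) column of locations against three scales.
x = normal_rvs(loc=[[0.0], [10.0]], scale=[1.0, 2.0, 3.0],
               rng=np.random.default_rng(0))
print(x.shape)  # (2, 3)
```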
Currently, the random number generators of the 1D distributions only
allow scalar shape, location and scale parameters as input, and the
size of the output is determined solely by the 'size' input variable.
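To illustrate what scalar-only parameters cost the caller, here is a rough sketch that uses NumPy's `Generator.normal` as a stand-in for a scipy rvs call (an assumption for illustration; it is not the scipy code under discussion), over a hypothetical 3 x 4 grid of locations and scales. With scalar parameters the result must be assembled in a Python loop, whereas a call that broadcasts its parameters yields the same-shaped array directly.

```python
import numpy as np

rng = np.random.default_rng(12345)

# Hypothetical parameter grid: 3 locations (a column) by 4 scales (a row).
loc = np.array([[0.0], [10.0], [20.0]])   # shape (3, 1)
scale = np.array([[1.0, 2.0, 3.0, 4.0]])  # shape (1, 4)

# Scalar-only parameters: loop over every combination and fill by hand.
looped = np.empty((3, 4))
for i in range(3):
    for j in range(4):
        looped[i, j] = rng.normal(loc[i, 0], scale[0, j])

# Broadcasting parameters: one call; the output shape follows numpy's rules.
broadcast = rng.normal(loc, scale)
print(looped.shape, broadcast.shape)  # (3, 4) (3, 4)
```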
pab
PS: One solution could be to redefine the rvs method in rv_continuous
as follows:
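def rvs(self, *args, **kwds):
    # Method of rv_continuous; relies on the module-level numpy imports
    # (all, logical_and, zeros) and on common_shape defined below.
    loc, scale, size = map(kwds.get, ['loc', 'scale', 'size'], [None, None, 1])
    args, loc, scale = self.__fix_loc_scale(args, loc, scale)
    # Reject invalid shape parameters and negative scales.
    cond = logical_and(self._argcheck(*args), (scale >= 0))
    if not all(cond):
        raise ValueError("Domain error in arguments.")
    # Broadcast size, loc, scale and the shape parameters to one common shape.
    cshape = common_shape(zeros(size), loc, scale, *args)
    #self._size = product(cshape)
    self._size = cshape
    vals = self._rvs(*args)
    return vals * scale + loc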
where
import numpy

def common_shape(*varargin):
    ''' Return the common shape of a sequence of arrays

    An error is raised if some of the arrays do not conform
    to the common shape according to the broadcasting rules in numpy.

    Example:
    >>> import pylab
    >>> A = pylab.rand(4,1)
    >>> B = 2
    >>> C = pylab.rand(1,5)
    >>> common_shape(A,B,C)
    (4, 5)
    '''
    varargout = numpy.atleast_1d(*varargin)
    if len(varargin) < 2:
        return tuple(varargout.shape)
    args_shape = [arg.shape for arg in varargout]
    ndims = [len(shape) for shape in args_shape]
    ndim = max(ndims)
    Np = len(varargin)
    all_shapes = numpy.ones((Np, ndim), dtype=int)
    for ix, Nt in enumerate(ndims):
        # numpy matches shapes from the trailing dimension, so pad on the left
        all_shapes[ix, ndim - Nt:] = args_shape[ix]
    ndims = numpy.atleast_1d(ndims)
    if numpy.any(ndims == 0):
        all_shapes[ndims == 0, :] = 0
    comn_shape = numpy.max(all_shapes, axis=0)
    # An array conforms if each dimension equals the common one or is 1
    arrays_do_not_conform2common_shape = numpy.any(
        numpy.logical_and(all_shapes != comn_shape[numpy.newaxis, ...],
                          all_shapes != 1),
        axis=1)
    if numpy.any(arrays_do_not_conform2common_shape):
        raise ValueError('Non-scalar input arguments do not match in '
                         'shape according to numpy broadcasting rules')
    return tuple(int(n) for n in comn_shape)
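For comparison, NumPy's existing `np.broadcast` object computes the same kind of common shape and raises `ValueError` for non-conforming inputs. The sketch below reuses the docstring example (with `np.random.rand` instead of `pylab.rand`) plus two hypothetical shape pairs; the second case shows that shapes are matched from their trailing dimensions, so (4, 5) and (5,) conform.

```python
import numpy as np

A = np.random.rand(4, 1)
B = 2
C = np.random.rand(1, 5)

# np.broadcast applies numpy's broadcasting rules and exposes the result shape.
print(np.broadcast(A, B, C).shape)  # (4, 5)

# Trailing dimensions are aligned: (4, 5) and (5,) broadcast to (4, 5).
print(np.broadcast(np.ones((4, 5)), np.ones(5)).shape)  # (4, 5)

# Non-conforming shapes, here (4, 5) and (4,), raise ValueError.
try:
    np.broadcast(np.ones((4, 5)), np.ones(4))
except ValueError as err:
    print("ValueError:", err)
```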