# [SciPy-dev] Enhancement proposal for generic.rvs method in scipy.stats.distributions

Per.Brodtkorb@f...
Thu Aug 14 06:12:00 CDT 2008

```
I would propose that the rvs method compute the common shape of its inputs (the shape, location and scale parameters, together with the requested size) according to numpy broadcasting rules.

I find this feature practical, and it is similar to what MATLAB does.

Currently, the random number generators of the 1D distributions only allow scalar shape, location and scale parameters as input, and the size of the output is determined solely by the 'size' argument.

pab

PS: One solution could be to redefine the rvs method in rv_continuous as follows:

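# Module-level imports assumed by the code below
import numpy
from numpy import all, any, atleast_1d, logical_and, newaxis, ones, zeros

def rvs(self, *args, **kwds):
    loc, scale, size = map(kwds.get, ['loc', 'scale', 'size'], [None, None, 1])
    args, loc, scale = self.__fix_loc_scale(args, loc, scale)
    cond = logical_and(self._argcheck(*args), (scale >= 0))
    if not all(cond):
        raise ValueError("Domain error in arguments.")
    # Broadcast the requested size against loc, scale and the shape parameters
    cshape = common_shape(zeros(size), loc, scale, *args)
    #self._size = product(cshape)
    self._size = cshape  # _rvs draws its samples with shape self._size
    vals = self._rvs(*args)
    return vals * scale + loc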

where

def common_shape(*varargin):
    ''' Return the common shape of a sequence of arrays.

    An error is raised if some of the arrays do not conform
    to the common shape according to the broadcasting rules in numpy.

    Example:
    >>> import numpy
    >>> A = numpy.random.rand(4,1)
    >>> B = 2
    >>> C = numpy.random.rand(1,5)
    >>> common_shape(A,B,C)
    (4, 5)
    '''
    varargout = atleast_1d(*varargin)
    if len(varargin) < 2:
        return tuple(varargout.shape)

    args_shape = [arg.shape for arg in varargout]
    ndims = [len(shape) for shape in args_shape]  # a list, so it can be reused
    ndim = max(ndims)
    Np = len(varargin)

    # Right-align the shapes, padding missing leading dimensions with 1,
    # since numpy broadcasting compares trailing dimensions first.
    all_shapes = ones((Np, ndim), dtype=int)
    for ix, Nt in enumerate(ndims):
        all_shapes[ix, ndim - Nt:] = args_shape[ix]

    comn_shape = numpy.max(all_shapes, axis=0)

    # An array conforms if each of its dimensions equals the common one or is 1
    arrays_do_not_conform2common_shape = any(
        logical_and(all_shapes != comn_shape[newaxis, ...], all_shapes != 1),
        axis=1)
    if any(arrays_do_not_conform2common_shape):
        raise ValueError('Non-scalar input arguments do not match in '
                         'shape according to numpy broadcasting rules')

    return tuple(comn_shape.tolist())

```
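NumPy broadcasting compares shapes starting from the trailing dimension: missing leading dimensions count as 1, and each dimension must either agree with the others or be 1. A minimal sketch of that rule as a standalone function follows. The name `common_shape` mirrors the helper in the message, but this is an illustrative reimplementation, not the code from the message or any SciPy API. NumPy 1.20 and later also provide `np.broadcast_shapes`, which can serve as a cross-check.

```python
import numpy as np


def common_shape(*arrays):
    """Return the shape NumPy broadcasting would give *arrays*.

    Illustrative sketch: shapes are right-aligned (trailing dimensions are
    compared first), missing leading dimensions count as 1, and every
    dimension must either match the others or be 1.
    """
    shapes = [np.shape(a) for a in arrays]
    ndim = max((len(s) for s in shapes), default=0)
    result = [1] * ndim
    for shape in shapes:
        # Pad on the left with 1s so the trailing dimensions line up.
        padded = (1,) * (ndim - len(shape)) + tuple(shape)
        for i, n in enumerate(padded):
            if n == 1:
                continue  # a length-1 dimension stretches to fit anything
            if result[i] == 1:
                result[i] = n
            elif result[i] != n:
                raise ValueError(
                    "shapes %s do not broadcast together" % (shapes,))
    return tuple(result)
```

For example, `common_shape(np.ones((4, 1)), np.ones(5))` returns `(4, 5)`, matching `np.broadcast_shapes((4, 1), (5,))`, while shapes `(3,)` and `(4,)` raise `ValueError`.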
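To illustrate the behaviour the message asks for, here is a sketch for the normal distribution only. It is written against the present-day NumPy `Generator` API rather than the 2008 `rv_continuous` internals, and `normal_rvs` is a hypothetical helper, not a SciPy function. The requested `size` is treated as one more shape in the broadcast, and the standard normal draws are then scaled and shifted elementwise, in the spirit of the `vals * scale + loc` step of the proposal.

```python
import numpy as np


def normal_rvs(loc=0.0, scale=1.0, size=None, rng=None):
    """Draw normal variates with loc, scale and size broadcast together.

    Hypothetical helper for illustration only (not part of SciPy).
    """
    rng = np.random.default_rng(rng)  # accepts None, an int seed or a Generator
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    if np.any(scale < 0):
        raise ValueError("Domain error in arguments.")
    # The output shape follows NumPy broadcasting over size and the parameters;
    # np.broadcast_shapes raises ValueError if they are incompatible.
    size_shape = () if size is None else tuple(int(n) for n in np.atleast_1d(size))
    shape = np.broadcast_shapes(size_shape, loc.shape, scale.shape)
    # Draw standard normals of the common shape, then scale and shift them.
    return loc + scale * rng.standard_normal(shape)
```

With `loc=[[0.0], [10.0]]`, `scale=[1.0, 2.0, 3.0]` and no `size`, the result has shape `(2, 3)`; an incompatible combination such as `loc=np.zeros(3), size=4` raises `ValueError`.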