[SciPy-dev] Enhancement proposal for generic.rvs method in scipy.stats.distributions

Per.Brodtkorb@f... Per.Brodtkorb@f...
Thu Aug 14 06:12:00 CDT 2008

I would propose that the rvs method compute the common shape of the
inputs (i.e., the shape, location and scale parameters and the size
information provided) according to numpy broadcasting rules.

I find this feature practical, and it is similar to what matlab does.


Currently, the random number generators of the 1D distributions only
allow scalar shape, location and scale parameters as input.

The size of the output is determined only by the 'size' input variable.
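
To illustrate what "common shape" would mean here, a small sketch using
numpy.broadcast. The parameter values and variable names are made up for
the example; they are not taken from scipy.stats.

```python
import numpy as np

# Hypothetical inputs: a column of shape parameters, a row of locations,
# a scalar scale and a requested output size.
shape_param = np.array([[1.0], [2.0], [3.0], [4.0]])  # shape (4, 1)
loc = np.zeros((1, 5))                                # shape (1, 5)
scale = 2.0                                           # scalar
size = (4, 5)

# numpy.broadcast applies the broadcasting rules without copying any data;
# its .shape attribute is the common shape of all the inputs.
common = np.broadcast(shape_param, loc, scale, np.zeros(size)).shape
print(common)
```

With the proposal, rvs would return an array of this common shape rather
than one determined by 'size' alone.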




PS: One solution could be to redefine the rvs method in the rv_continuous
class as follows:


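def rvs(self,*args,**kwds):
    # Assumed: loc, scale and size are read from the keyword arguments
    # (this step is missing from the listing as archived).
    loc, scale, size = map(kwds.get, ['loc', 'scale', 'size'])
    if size is None:
        size = 1
    args, loc, scale = self.__fix_loc_scale(args, loc, scale)
    cond = logical_and(self._argcheck(*args),(scale >= 0))
    if not all(cond):
        raise ValueError("Domain error in arguments.")

    # Common shape of size, loc, scale and the shape parameters
    cshape = common_shape(zeros(size),loc,scale,*args)
    #self._size = product(cshape)
    self._size = cshape
    vals = self._rvs(*args)
    return vals * scale + loc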
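
The same idea can be tried outside of scipy. Below is a minimal,
self-contained sketch for the normal distribution only. The function name
normal_rvs and its rng argument are hypothetical, and numpy.broadcast
stands in for the common_shape helper.

```python
import numpy as np

def normal_rvs(loc=0.0, scale=1.0, size=1, rng=None):
    """Hypothetical stand-in for rvs: normal variates with broadcast loc/scale."""
    rng = np.random.default_rng() if rng is None else rng
    loc, scale = np.asarray(loc), np.asarray(scale)
    if not np.all(scale >= 0):
        raise ValueError("Domain error in arguments.")
    # Common shape of size, loc and scale; numpy.broadcast raises
    # ValueError if the inputs do not conform.
    cshape = np.broadcast(np.zeros(size), loc, scale).shape
    vals = rng.standard_normal(cshape)
    return vals * scale + loc

# A column of locations and a row of scales give a 4 x 5 sample.
sample = normal_rvs(loc=np.zeros((4, 1)), scale=np.ones((1, 5)))
print(sample.shape)
```

Drawing the standardized variates with the full common shape first, and
only then scaling and shifting, keeps one independent draw per output
element.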





import numpy
from numpy import atleast_1d, ones

def common_shape(*varargin):
    ''' Return the common shape of a sequence of arrays

    An error is raised if some of the arrays do not conform
    to the common shape according to the broadcasting rules in numpy.

       >>> import pylab
       >>> A = pylab.rand(4,1)
       >>> B = 2
       >>> C = pylab.rand(1,5)
       >>> common_shape(A,B,C)
       (4, 5)
    '''
    varargout = atleast_1d(*varargin)

    if len(varargin)<2:
        return tuple(varargout.shape)

    args_shape = [arg.shape for arg in varargout]
    ndims = [len(shape) for shape in args_shape]
    ndim = max(ndims)
    Np = len(varargin)

    # Pad each shape with leading ones: numpy aligns shapes on the last axis
    all_shapes = ones((Np, ndim),dtype=int)
    for ix, Nt in enumerate(ndims):
        all_shapes[ix, ndim-Nt:] = args_shape[ix]

    ndims = atleast_1d(ndims)
    if any(ndims == 0):
        all_shapes[ndims == 0, :] = 0

    comn_shape = numpy.max(all_shapes, axis=0)

    # Assumed condition (the right-hand side is missing from the archived
    # listing): an array does not conform if any of its dimensions is
    # neither 1 nor equal to the common dimension.
    arrays_do_not_conform2common_shape = numpy.any(
        numpy.logical_and(all_shapes != comn_shape, all_shapes != 1), axis=1)

    if any(arrays_do_not_conform2common_shape):
        raise ValueError('Non-scalar input arguments do not match in '
                         'shape according to numpy broadcasting rules')

    return tuple(comn_shape)
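
For reference, the broadcasting rule itself can be written on plain shape
tuples, without building any arrays. This is only a sketch of the rule as
numpy documents it (shapes are aligned on their last axis, and a dimension
of 1 stretches); the name broadcast_shape is hypothetical.

```python
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Hypothetical helper: combine shape tuples with numpy's broadcasting rule."""
    result = []
    # Walk the axes from the last one backwards; a missing axis counts as 1.
    for dims in zip_longest(*[reversed(s) for s in shapes], fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError("shapes %r do not broadcast" % (shapes,))
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

# A (4, 1) array, a scalar and a (1, 5) array share the shape (4, 5).
print(broadcast_shape((4, 1), (), (1, 5)))
# Alignment is on the last axis, so (5,) conforms with (4, 5) ...
print(broadcast_shape((5,), (4, 5)))
# ... whereas broadcast_shape((4,), (4, 5)) raises ValueError.
```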
