[Numpy-discussion] Multiplying Python float to numpy.array of objects works but fails with a numpy.float64, numpy Bug?

Keith Goodman kwgoodman@gmail....
Tue Jun 2 09:09:31 CDT 2009


On Tue, Jun 2, 2009 at 1:42 AM, Sebastian Walter
<sebastian.walter@gmail.com> wrote:
> Hello,
> Multiplying a numpy.array of objects by a Python float works flawlessly,
> but multiplying it by a numpy.float64 does not.
> I tried NumPy version '1.0.4' on 32-bit Linux and '1.2.1' on 64-bit
> Linux: both raise the same exception.
>
> Is this a (known) bug?
>
> ---------------------- test.py ------------------------------------
> from numpy import *
>
> class adouble:
>        def __init__(self,x):
>                self.x = x
>        def __mul__(self,rhs):
>                if isinstance(rhs,adouble):
>                        return adouble(self.x * rhs.x)
>                else:
>                        return adouble(self.x * rhs)
>        def __str__(self):
>                return str(self.x)
>
> x = adouble(3.)
> y = adouble(2.)
> u = array([adouble(3.), adouble(5.)])
> v = array([adouble(2.), adouble(7.)])
> z = array([2.,3.])
>
> print x * y              # ok
> print u * v              # ok
> print u * z              # ok
> print u * 3.             # ok
> print u * z[0]           # _NOT_ OK!
> print u * float64(3.)    # _NOT_ OK!
>
>
>
> ---------------------- output   ---------------------------------
> walter@wronski$ python test.py
> 6.0
> [6.0 35.0]
> [6.0 15.0]
> [9.0 15.0]
> Traceback (most recent call last):
>  File "test.py", line 24, in <module>
>    print u * z[0]   # _NOT_ OK!
> TypeError: unsupported operand type(s) for *: 'numpy.ndarray' and
> 'numpy.float64'
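For context, an object array applies the ordinary Python * operator to each element, so for u * 3. each adouble's __mul__ receives the scalar. Below is a minimal Python 3 sketch of that forward case, assuming a recent NumPy release. It does not reproduce the 1.0.4 / 1.2.1 failure reported above; on a current release the numpy.float64 case is expected to go through the same elementwise path. The reflected order (3. * u) would additionally need adouble.__rmul__.

```python
# Minimal Python 3 sketch, assuming a recent NumPy release (not 1.0.4 / 1.2.1).
import numpy as np


class adouble:
    """Same idea as the class in the report: a thin wrapper around a float."""

    def __init__(self, x):
        self.x = x

    def __mul__(self, rhs):
        # Unwrap another adouble; otherwise multiply by rhs directly.
        if isinstance(rhs, adouble):
            return adouble(self.x * rhs.x)
        return adouble(self.x * rhs)


# dtype=object: NumPy stores references to the adouble instances and
# applies the Python * operator element by element.
u = np.array([adouble(3.0), adouble(5.0)], dtype=object)

by_python_float = u * 3.0              # calls adouble.__mul__(3.0) per element
by_numpy_scalar = u * np.float64(3.0)  # the case reported as failing above

print([float(e.x) for e in by_python_float])
print([float(e.x) for e in by_numpy_scalar])
```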

Try adding __rmul__ = __mul__ to the class, as below:

from numpy import *

class adouble:
    def __init__(self,x):
        self.x = x
    def __mul__(self,rhs):
        if isinstance(rhs,adouble):
            return adouble(self.x * rhs.x)
        else:
            return adouble(self.x * rhs)
    __rmul__ = __mul__  # reflected multiply: used for 3 * x, when the left operand can't handle adouble
    def __str__(self):
        return str(self.x)

def test():
    x = adouble(3.)
    print 3 * x

Output:

>> test.test()
9.0
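Why this helps: for a * b, Python first calls a.__mul__(b); if that method is missing or returns NotImplemented, it falls back to b.__rmul__(a). A minimal Python 3 sketch of that dispatch follows. PlainDouble (no __rmul__) and ReflectedDouble (with __rmul__ = __mul__) are hypothetical names for illustration only; aliasing __rmul__ to __mul__ is only safe because multiplying the wrapped floats is commutative.

```python
# Minimal Python 3 sketch of reflected multiplication.
# PlainDouble and ReflectedDouble are hypothetical classes for illustration.


class PlainDouble:
    """Like adouble from the thread, but without __rmul__."""

    def __init__(self, x):
        self.x = x

    def __mul__(self, rhs):
        if isinstance(rhs, PlainDouble):
            return PlainDouble(self.x * rhs.x)
        return PlainDouble(self.x * rhs)


class ReflectedDouble(PlainDouble):
    """Same as PlainDouble, plus the __rmul__ = __mul__ alias."""

    def __mul__(self, rhs):
        if isinstance(rhs, PlainDouble):
            return ReflectedDouble(self.x * rhs.x)
        return ReflectedDouble(self.x * rhs)

    # int.__mul__ returns NotImplemented for these objects, so Python
    # then tries the right operand's __rmul__.
    __rmul__ = __mul__


def try_left_multiply(value):
    """Return (3 * value).x, or the TypeError message if Python gives up."""
    try:
        return (3 * value).x
    except TypeError as exc:
        return str(exc)


print(try_left_multiply(PlainDouble(3.0)))      # TypeError message
print(try_left_multiply(ReflectedDouble(3.0)))  # 9.0
```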

