[Numpy-discussion] Do we want scalar casting to behave as it does at the moment?

Nathaniel Smith njs@pobox....
Fri Jan 4 10:54:27 CST 2013


On Fri, Jan 4, 2013 at 4:01 PM, Andrew Collette
<andrew.collette@gmail.com> wrote:
> Hi Olivier,
>
>> A key difference is that with arrays, the dtype is not chosen "just
>> big enough" for your data to fit. Either you set the dtype yourself,
>> or you're using the default inferred dtype (int/float). In both cases
>> you should know what to expect, and it doesn't depend on the actual
>> numeric values (except for the auto int/float distinction).
>
> Yes, certainly; for example, you would get an int32/int64 if you
> simply do "array(4)". What I mean is, when you do "a+b" and b is a
> scalar, I had assumed that the normal array rules for addition apply,
> if you treat the dtype of b as being the smallest precision possible
> which can hold that value.  E.g. 1 (int8) + 42 would treat 42 as an
> int8, and 1 (int8) + 200 would treat 200 as an int16.  If I'm not
> mistaken, this is what happens currently.

Well, that's the thing... there is actually *no* version of numpy
where the "normal rules" apply to scalars. If
  a = np.array([1, 2, 3], dtype=np.uint8)
then in numpy 1.5 and earlier we had
  # Python scalars
  (a / 1).dtype == np.uint8
  (a / 300).dtype == np.uint8
  # Numpy scalars
  (a / np.int_(1)).dtype == np.uint8
  (a / np.int_(300)).dtype == np.uint8
  # Arrays
  (a / [1]).dtype == np.int_
  (a / [300]).dtype == np.int_

In 1.6 we have:
  # Python scalars
  (a / 1).dtype == np.uint8
  (a / 300).dtype == np.uint16
  # Numpy scalars
  (a / np.int_(1)).dtype == np.uint8
  (a / np.int_(300)).dtype == np.uint16
  # Arrays
  (a / [1]).dtype == np.int_
  (a / [300]).dtype == np.int_

In fact, in 1.6 there is no dtype we could assign to '1' that would make
the way 1.6 handles it consistent with the array rules:
  # Ah-hah, it looks like '1' has a uint8 dtype:
  (np.ones(2, dtype=np.uint8) / np.ones(2, dtype=np.uint8)).dtype == np.uint8
  (np.ones(2, dtype=np.uint8) / 1).dtype == np.uint8
  # But wait! No it doesn't!
  (np.ones(2, dtype=np.int8) / np.ones(2, dtype=np.uint8)).dtype == np.int16
  (np.ones(2, dtype=np.int8) / 1).dtype == np.int8
  # Apparently in this case it has an int8 dtype instead.
  (np.ones(2, dtype=np.int8) / np.ones(2, dtype=np.int8)).dtype == np.int8

In 1.5, the special rule for (same-kind) scalars is that we always
cast them to the array's type.
In 1.6, the special rule for (same-kind) scalars is that we cast them
to some type which is a function of the array's type, and the scalar's
value, but not the scalar's type.

This is especially confusing because normally in numpy the *only* way
to get a dtype that is not in the set [np.bool_, np.int_, np.float64,
np.complex128, np.object_] (the dtypes produced by np.array(pyobj)) is
to explicitly request it by name. So if you're memory-constrained, a
useful mental model is to think that there are two types of arrays:
your compact ones that use the specific limited-precision type you've
picked (uint8, float32, whichever), and "regular" arrays, which use
machine precision. And all you have to keep track of is the
interaction between these. But in 1.6, as soon as you have a uint8
array, suddenly all the other precisions might spring magically into
being at any moment.
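To make the two-kinds-of-arrays model concrete, here is a pure-Python sketch of which dtype name `np.array(pyobj)` picks for a single Python scalar, restricted to the five defaults listed above, plus a helper that flags a result dtype which is neither a default nor the compact dtype you chose. Both `default_dtype` and `is_surprise` are hypothetical illustrations, not NumPy API; strings and oversized ints are deliberately ignored.

```python
# Hypothetical sketch: the default dtypes np.array(obj) yields for scalar obj,
# limited to the five "regular" dtypes named in the post.
DEFAULTS = ("bool_", "int_", "float64", "complex128", "object_")

def default_dtype(obj):
    if isinstance(obj, bool):      # check bool first: bool is a subclass of int
        return "bool_"
    if isinstance(obj, int):       # simplification: ignores ints too big for int_
        return "int_"
    if isinstance(obj, float):
        return "float64"
    if isinstance(obj, complex):
        return "complex128"
    return "object_"               # simplification: NumPy gives str/bytes their own dtypes

def is_surprise(result_dtype, compact_dtype):
    """True if an operation produced a precision nobody explicitly asked for."""
    return result_dtype not in DEFAULTS and result_dtype != compact_dtype

print([default_dtype(x) for x in (True, 4, 4.0, 4j, None)])
print(is_surprise("int_", "uint8"))    # False: a "regular" array
print(is_surprise("uint16", "uint8"))  # True: the 1.6 result of a / 300
```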

So, the options:
If we require that new dtypes never be introduced out of nowhere, then
we have to pick from:
  1) a / 300 silently rolls over the 300 before attempting the
operation (1.5-style)
  2) a / 300 upcasts to machine precision (use the same rules for
arrays and scalars)
  3) a / 300 gives an error (the proposal you don't like)
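The three options can be compared side by side with a small pure-Python sketch for a uint8 array. The function `divide_operand` and its policy names are hypothetical; it returns the operand value the division would actually see and the resulting dtype name under each option, assuming option 1 wraps modulo 256 as C casts do and option 2 treats the scalar exactly like `[300]`.

```python
# Hypothetical sketch of the three candidate behaviours for `uint<bits>_array / value`.
def divide_operand(value, policy, bits=8):
    """Return (operand value seen by the division, result dtype name)."""
    limit = 2 ** bits
    compact = "uint%d" % bits
    if policy == "rollover":   # option 1 (1.5-style): wrap modulo 2**bits, keep array dtype
        return value % limit, compact
    if policy == "upcast":     # option 2: same rule as for arrays, i.e. machine-precision int
        return value, "int_"
    if policy == "error":      # option 3: refuse values the array dtype cannot hold
        if not 0 <= value < limit:
            raise OverflowError("%d does not fit in %s" % (value, compact))
        return value, compact
    raise ValueError("unknown policy: %r" % (policy,))

for policy in ("rollover", "upcast", "error"):
    try:
        print(policy, divide_operand(300, policy))
    except OverflowError as exc:
        print(policy, "raises:", exc)
```

Only option 1 silently changes the number (300 becomes 44); option 2 changes the dtype even for small values like 1, and option 3 changes nothing but refuses to run.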

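If we instead treat a Python scalar like 1 as having the smallest
precision dtype that can hold its value, then 1 is either int8 or
uint8, and mixing it with an array of the other signedness upcasts
under the array rules. So we have to accept either
  uint8 + 1 -> int16   (if 1 is int8)
or
  int8 + 1 -> int16    (if 1 is uint8)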

Or there's the current code, whose behaviour no-one actually
understands. (And I mean that both figuratively -- it's clearly
confusing enough that people won't be able to remember it well in
practice -- and literally -- even we developers don't know what it
will do without running it to see.)

-n

