[Numpy-svn] r5891 - in branches/1.2.x/numpy/core: src tests

numpy-svn@scip... numpy-svn@scip...
Thu Oct 2 15:37:01 CDT 2008


Author: oliphant
Date: 2008-10-02 15:37:00 -0500 (Thu, 02 Oct 2008)
New Revision: 5891

Modified:
   branches/1.2.x/numpy/core/src/ufuncobject.c
   branches/1.2.x/numpy/core/tests/test_umath.py
Log:
BUG: Backport fix to object arrays in r5889 to 1.2.x branch.

Modified: branches/1.2.x/numpy/core/src/ufuncobject.c
===================================================================
--- branches/1.2.x/numpy/core/src/ufuncobject.c	2008-10-02 20:33:57 UTC (rev 5890)
+++ branches/1.2.x/numpy/core/src/ufuncobject.c	2008-10-02 20:37:00 UTC (rev 5891)
@@ -1427,12 +1427,15 @@
      * FAIL with NotImplemented if the other object has
      * the __r<op>__ method and has __array_priority__ as
      * an attribute (signalling it can handle ndarray's)
-     * and is not already an ndarray
+     * and is not already an ndarray or a subtype of the same type.
      */
     if ((arg_types[1] == PyArray_OBJECT) &&                         \
         (loop->ufunc->nin==2) && (loop->ufunc->nout == 1)) {
         PyObject *_obj = PyTuple_GET_ITEM(args, 1);
-        if (!PyArray_CheckExact(_obj) &&                        \
+        if (!PyArray_CheckExact(_obj) &&
+	    /* If both are same subtype of object arrays, then proceed */
+	    !(_obj->ob_type == (PyTuple_GET_ITEM(args, 0))->ob_type) &&   \
+
             PyObject_HasAttrString(_obj, "__array_priority__") && \
             _has_reflected_op(_obj, loop->ufunc->name)) {
             loop->notimplemented = 1;

Modified: branches/1.2.x/numpy/core/tests/test_umath.py
===================================================================
--- branches/1.2.x/numpy/core/tests/test_umath.py	2008-10-02 20:33:57 UTC (rev 5890)
+++ branches/1.2.x/numpy/core/tests/test_umath.py	2008-10-02 20:37:00 UTC (rev 5891)
@@ -278,6 +278,16 @@
         assert_equal(add.nout, 1)
         assert_equal(add.identity, 0)
 
+class TestSubclass(TestCase):  # regression test for the r5889 object-array fix
+    def test_subclass_op(self):
+        class simple(np.ndarray):  # ndarray subclass with dtype=object
+            def __new__(subtype, shape):
+                self = np.ndarray.__new__(subtype, shape, dtype=object)
+                self.fill(0)  # zero-fill so that a + a == a below
+                return self
+        a = simple((3,4))
+        assert_equal(a+a, a)  # same-subtype operands must not hit NotImplemented
+
 def _check_branch_cut(f, x0, dx, re_sign=1, im_sign=-1, sig_zero_ok=False,
                       dtype=np.complex):
     """



More information about the Numpy-svn mailing list