[Numpy-svn] r5255 - in trunk/numpy/core: src tests

numpy-svn@scip... numpy-svn@scip...
Thu Jun 5 18:27:56 CDT 2008


Author: oliphant
Date: 2008-06-05 18:27:52 -0500 (Thu, 05 Jun 2008)
New Revision: 5255

Modified:
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/tests/test_regression.py
Log:
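Fix more of ticket #791 (ndarray subclass attributes lost by .imag and other methods): pass the source array as the parent object when allocating results, and factor zero-filling into _zerofill().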

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2008-06-05 17:40:15 UTC (rev 5254)
+++ trunk/numpy/core/src/arrayobject.c	2008-06-05 23:27:52 UTC (rev 5255)
@@ -6641,6 +6641,26 @@
     }
 }
 
+/* Zero-fill ret: Python int 0 for dtypes that hold references, zero bytes
+   otherwise.  Returns 0 on success; on failure DECREFs ret (consuming the
+   caller's reference) and returns -1, so callers just return NULL. */
+static int
+_zerofill(PyArrayObject *ret)
+{
+    if (PyDataType_REFCHK(ret->descr)) {
+        PyObject *zero = PyInt_FromLong(0);
+        if (zero == NULL) {Py_DECREF(ret); return -1;}
+        PyArray_FillObjectArray(ret, zero);
+        Py_DECREF(zero);
+        if (PyErr_Occurred()) {Py_DECREF(ret); return -1;}
+    }
+    else {
+        intp n = PyArray_NBYTES(ret);
+        memset(ret->data, 0, n);
+    }
+    return 0;
+}
+
 /* Create a view of a complex array with an equivalent data-type
    except it is real instead of complex.
 */
@@ -6722,29 +6742,30 @@
 array_imag_get(PyArrayObject *self)
 {
     PyArrayObject *ret;
-    PyArray_Descr *type;
 
     if (PyArray_ISCOMPLEX(self)) {
         ret = _get_part(self, 1);
-        return (PyObject *) ret;
     }
     else {
-        type = self->descr;
-        Py_INCREF(type);
-        ret = (PyArrayObject *)PyArray_Zeros(self->nd,
-                                             self->dimensions,
-                                             type,
-                                             PyArray_ISFORTRAN(self));
+        /* Allocate directly as self's subtype with self as the parent object
+           so __array_finalize__ can carry subclass attributes over (ticket
+           #791).  PyArray_NewFromDescr steals the descr reference. */
+        Py_INCREF(self->descr);
+        ret = (PyArrayObject *)PyArray_NewFromDescr(self->ob_type,
+                                                    self->descr,
+                                                    self->nd,
+                                                    self->dimensions,
+                                                    NULL, NULL,
+                                                    PyArray_ISFORTRAN(self),
+                                                    (PyObject *)self);
+        if (ret == NULL) return NULL;
+
+        /* _zerofill DECREFs ret itself when it fails. */
+        if (_zerofill(ret) < 0) return NULL;
+
         ret->flags &= ~WRITEABLE;
-        if (PyArray_CheckExact(self))
-            return (PyObject *)ret;
-        else {
-            PyObject *newret;
-            newret = PyArray_View(ret, NULL, self->ob_type);
-            Py_DECREF(ret);
-            return newret;
-        }
     }
+    return (PyObject *) ret;
 }
 
 static int

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2008-06-05 17:40:15 UTC (rev 5254)
+++ trunk/numpy/core/src/multiarraymodule.c	2008-06-05 23:27:52 UTC (rev 5255)
@@ -1345,7 +1345,7 @@
                                             self->dimensions,
                                             NULL, NULL,
                                             PyArray_ISFORTRAN(self),
-                                            NULL);
+                                            (PyObject *)self);
         if (out == NULL) goto fail;
         outgood = 1;
     }
@@ -5886,7 +5886,6 @@
     return ret;
 }
 
-
 /* steal a reference */
 /* accepts NULL type */
 /*NUMPY_API
@@ -5896,7 +5895,6 @@
 PyArray_Zeros(int nd, intp *dims, PyArray_Descr *type, int fortran)
 {
     PyArrayObject *ret;
-    intp n;
 
     if (!type) type = PyArray_DescrFromType(PyArray_DEFAULT);
     ret = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
@@ -5906,16 +5904,7 @@
                                                 fortran, NULL);
     if (ret == NULL) return NULL;
 
-    if (PyDataType_REFCHK(type)) {
-        PyObject *zero = PyInt_FromLong(0);
-        PyArray_FillObjectArray(ret, zero);
-        Py_DECREF(zero);
-        if (PyErr_Occurred()) {Py_DECREF(ret); return NULL;}
-    }
-    else {
-        n = PyArray_NBYTES(ret);
-        memset(ret->data, 0, n);
-    }
+    if (_zerofill(ret) < 0) return NULL;
     return (PyObject *)ret;
 
 }

Modified: trunk/numpy/core/tests/test_regression.py
===================================================================
--- trunk/numpy/core/tests/test_regression.py	2008-06-05 17:40:15 UTC (rev 5254)
+++ trunk/numpy/core/tests/test_regression.py	2008-06-05 23:27:52 UTC (rev 5255)
@@ -1053,6 +1053,25 @@
         except TypeError:
             pass
 
+    def check_attributes(self, level=rlevel):
+        """Ticket #791: subclass attributes must survive mean/std/clip/imag."""
+        import numpy as np
+        class TestArray(np.ndarray):
+            def __new__(cls, data, info):
+                result = np.array(data)
+                result = result.view(cls)
+                result.info = info
+                return result
+            def __array_finalize__(self, obj):
+                self.info = getattr(obj, 'info', '')
+        dat = TestArray([[1,2,3,4],[5,6,7,8]],'jubba')
+        assert dat.info == 'jubba'
+        assert dat.mean(1).info == 'jubba'
+        assert dat.std(1).info == 'jubba'
+        assert dat.clip(2,7).info == 'jubba'
+        assert dat.imag.info == 'jubba'
+        assert not dat.imag.flags.writeable
+
 
     def check_recarray_tolist(self, level=rlevel):
         """Ticket #793, changeset r5215



More information about the Numpy-svn mailing list