[Numpy-svn] r4065 - in trunk/numpy/core: src tests

numpy-svn@scip... numpy-svn@scip...
Thu Sep 20 20:22:49 CDT 2007


Author: oliphant
Date: 2007-09-20 20:22:44 -0500 (Thu, 20 Sep 2007)
New Revision: 4065

Modified:
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/tests/test_regression.py
Log:
Fix ticket #546: argmax/argmin returned invalid indices for arrays with non-native byte order.

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2007-09-20 21:12:04 UTC (rev 4064)
+++ trunk/numpy/core/src/multiarraymodule.c	2007-09-21 01:22:44 UTC (rev 4065)
@@ -3675,9 +3675,11 @@
         op = ap;
     }
 
+    /* Will get native-byte order contiguous copy. 
+     */
     ap = (PyArrayObject *)\
-        PyArray_ContiguousFromAny((PyObject *)op,
-                                  PyArray_NOTYPE, 1, 0);
+        PyArray_ContiguousFromAny((PyObject *)op, 
+				  op->descr->type_num, 1, 0);
 
     Py_DECREF(op);
     if (ap == NULL) return NULL;
@@ -3693,7 +3695,7 @@
     if (m == 0) {
         PyErr_SetString(MultiArrayError,
                         "attempt to get argmax/argmin "\
-                        "of an empty sequence??");
+                        "of an empty sequence");
         goto fail;
     }
 
@@ -3719,7 +3721,7 @@
     }
 
     NPY_BEGIN_THREADS_DESCR(ap->descr)
-        n = PyArray_SIZE(ap)/m;
+    n = PyArray_SIZE(ap)/m;
     rptr = (intp *)rp->data;
     for (ip = ap->data, i=0; i<n; i++, ip+=elsize*m) {
         arg_func(ip, m, rptr, ap);
@@ -3727,7 +3729,7 @@
     }
     NPY_END_THREADS_DESCR(ap->descr)
 
-        Py_DECREF(ap);
+    Py_DECREF(ap);
     if (copyret) {
         PyArrayObject *obj;
         obj = (PyArrayObject *)rp->base;

Modified: trunk/numpy/core/tests/test_regression.py
===================================================================
--- trunk/numpy/core/tests/test_regression.py	2007-09-20 21:12:04 UTC (rev 4064)
+++ trunk/numpy/core/tests/test_regression.py	2007-09-21 01:22:44 UTC (rev 4065)
@@ -697,6 +697,11 @@
         x = N.array(['a']*32)
         assert_array_equal(x.argsort(kind='m'), N.arange(32))
 
+    def check_argmax_byteorder(self, level=rlevel):
+        """Ticket #546"""
+        a = arange(3, dtype='>f')
+        assert a[a.argmax()] == a.max()
+
     def check_numeric_random(self, level=rlevel):
         """Ticket #552"""
         from numpy.oldnumeric.random_array import randint
@@ -720,5 +725,6 @@
         """Ticket #572"""
         N.lib.place(1,1,1)
 
+
 if __name__ == "__main__":
     NumpyTest().run()



More information about the Numpy-svn mailing list