[Numpy-svn] r3074 - in trunk/numpy: core/code_generators core/src lib

numpy-svn at scipy.org numpy-svn at scipy.org
Sat Aug 26 02:50:57 CDT 2006


Author: oliphant
Date: 2006-08-26 02:50:53 -0500 (Sat, 26 Aug 2006)
New Revision: 3074

Modified:
   trunk/numpy/core/code_generators/multiarray_api_order.txt
   trunk/numpy/core/src/arraymethods.c
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/lib/function_base.py
Log:
Fix broadcast-copy on fancy set-item.

Modified: trunk/numpy/core/code_generators/multiarray_api_order.txt
===================================================================
--- trunk/numpy/core/code_generators/multiarray_api_order.txt	2006-08-26 00:49:30 UTC (rev 3073)
+++ trunk/numpy/core/code_generators/multiarray_api_order.txt	2006-08-26 07:50:53 UTC (rev 3074)
@@ -75,5 +75,6 @@
 PyArray_TypeNumFromName
 PyArray_ClipmodeConverter
 PyArray_OutputConverter
+PyArray_BroadcastToShape
 _PyArray_SigintHandler
-_PyArray_GetSigintBuf
\ No newline at end of file
+_PyArray_GetSigintBuf

Modified: trunk/numpy/core/src/arraymethods.c
===================================================================
--- trunk/numpy/core/src/arraymethods.c	2006-08-26 00:49:30 UTC (rev 3073)
+++ trunk/numpy/core/src/arraymethods.c	2006-08-26 07:50:53 UTC (rev 3074)
@@ -2,8 +2,8 @@
 /* Should only be used if x is known to be an nd-array */
 #define _ARET(x) PyArray_Return((PyArrayObject *)(x))
 
-static char doc_take[] = "a.take(indices, axis=None).  Selects the elements "\
-	"in indices from array a along the given axis.";
+static char doc_take[] = "a.take(indices, axis=None, out=None, mode='raise').  "\
+        "Selects the elements in indices from array a along the given axis.";
 
 static PyObject *
 array_take(PyArrayObject *self, PyObject *args, PyObject *kwds)
@@ -41,9 +41,9 @@
 	return Py_None;
 }
 
-static char doc_put[] = "a.put(values, indices, mode) sets a.flat[n] = v[n] "\
-	"for each n in indices. v can be scalar or shorter than indices, "\
-	"will repeat.";
+static char doc_put[] = "a.put(values, indices, mode) sets a.flat[n] = "\
+        "values[n] for\n" "each n in indices. If values is a scalar or "\
+        "shorter than indices,\n" "it will be repeated.";
 
 static PyObject *
 array_put(PyArrayObject *self, PyObject *args, PyObject *kwds)

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-08-26 00:49:30 UTC (rev 3073)
+++ trunk/numpy/core/src/arrayobject.c	2006-08-26 07:50:53 UTC (rev 3074)
@@ -2399,8 +2399,11 @@
                 }
         }
 
-
-        if ((it = (PyArrayIterObject *)PyArray_IterNew(arr))==NULL) {
+        /* Be sure values array is "broadcastable" 
+           to shape of mit->dimensions, mit->nd */
+        
+        if ((it = (PyArrayIterObject *)\
+             PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd))==NULL) {
                 Py_DECREF(arr);
                 return -1;
         }
@@ -2408,7 +2411,6 @@
         index = mit->size;
         swap = (PyArray_ISNOTSWAPPED(mit->ait->ao) != \
                 (PyArray_ISNOTSWAPPED(arr)));
-
         copyswap = PyArray_DESCR(arr)->f->copyswap;
         PyArray_MapIterReset(mit);
         /* Need to decref hasobject arrays */
@@ -2421,8 +2423,6 @@
                         copyswap(mit->dataptr, NULL, swap, arr);
                         PyArray_MapIterNext(mit);
                         PyArray_ITER_NEXT(it);
-                        if (it->index == it->size)
-                                PyArray_ITER_RESET(it);
                 }
                 Py_DECREF(arr);
                 Py_DECREF(it);
@@ -2434,8 +2434,6 @@
                         copyswap(mit->dataptr, NULL, swap, arr);
                 PyArray_MapIterNext(mit);
                 PyArray_ITER_NEXT(it);
-                if (it->index == it->size)
-                        PyArray_ITER_RESET(it);
         }
         Py_DECREF(arr);
         Py_DECREF(it);
@@ -2703,7 +2701,8 @@
                 if (oned) {
                         PyArrayIterObject *it;
                         PyObject *rval;
-                        it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)self);
+                        it = (PyArrayIterObject *)\
+                                PyArray_IterNew((PyObject *)self);
                         if (it == NULL) {Py_DECREF(mit); return NULL;}
                         rval = iter_subscript(it, mit->indexobj);
                         Py_DECREF(it);
@@ -8592,28 +8591,94 @@
 
         nd = ao->nd;
         PyArray_UpdateFlags(ao, CONTIGUOUS);
-        it->contiguous = 0;
         if PyArray_ISCONTIGUOUS(ao) it->contiguous = 1;
+        else it->contiguous = 0;
         Py_INCREF(ao);
         it->ao = ao;
         it->size = PyArray_SIZE(ao);
         it->nd_m1 = nd - 1;
         it->factors[nd-1] = 1;
         for (i=0; i < nd; i++) {
-                it->dims_m1[i] = it->ao->dimensions[i] - 1;
-                it->strides[i] = it->ao->strides[i];
+                it->dims_m1[i] = ao->dimensions[i] - 1;
+                it->strides[i] = ao->strides[i];
                 it->backstrides[i] = it->strides[i] *   \
                         it->dims_m1[i];
                 if (i > 0)
                         it->factors[nd-i-1] = it->factors[nd-i] *       \
+                                ao->dimensions[nd-i];
+        }
+        PyArray_ITER_RESET(it);
+
+        return (PyObject *)it;
+}
+
+/*MULTIARRAY_API
+  Get Iterator broadcast to a particular shape
+
+  obj's axes align with the trailing entries of dims and must equal them
+  or be 1; stretched and missing axes get stride 0.  Returns a new
+  reference, or NULL with ValueError/MemoryError set.
+ */
+static PyObject *
+PyArray_BroadcastToShape(PyObject *obj, intp *dims, int nd)
+{
+        PyArrayIterObject *it;
+        int i, diff, j, k;
+        PyArrayObject *ao = (PyArrayObject *)obj;
+
+        /* obj may gain leading axes but can never lose any */
+        if (ao->nd > nd) goto err;
+        diff = nd - ao->nd;
+        for (i=0, j=diff; i<ao->nd; i++, j++) {
+                if (ao->dimensions[i] != 1 &&
+                    ao->dimensions[i] != dims[j]) goto err;
+        }
+
+        it = (PyArrayIterObject *)_pya_malloc(sizeof(PyArrayIterObject));
+        if (it == NULL)
+                return PyErr_NoMemory();
+        PyObject_Init((PyObject *)it, &PyArrayIter_Type);
+
+        PyArray_UpdateFlags(ao, CONTIGUOUS);
+        if (PyArray_ISCONTIGUOUS(ao)) it->contiguous = 1;
+        else it->contiguous = 0;
+        Py_INCREF(ao);
+        it->ao = ao;
+        it->size = PyArray_MultiplyList(dims, nd);
+        it->nd_m1 = nd - 1;
+        if (nd > 0) it->factors[nd-1] = 1;  /* 0-d: no factors to set */
+        for (i=0; i < nd; i++) {
+                it->dims_m1[i] = dims[i] - 1;
+                k = i - diff;  /* axis of ao aligned with axis i of dims */
+                if ((k < 0) ||
+                    ao->dimensions[k] != dims[i]) {
+                        it->contiguous = 0;
+                        it->strides[i] = 0;  /* broadcast: reuse same data */
+                }
+                else {
+                        it->strides[i] = ao->strides[k];
+                }
+                it->backstrides[i] = it->strides[i] *   \
+                        it->dims_m1[i];
+                /* factors describe the broadcast shape, not ao's shape */
+                if (i > 0)
+                        it->factors[nd-i-1] = it->factors[nd-i] *       \
-                                it->ao->dimensions[nd-i];
+                                dims[nd-i];
         }
         PyArray_ITER_RESET(it);
 
         return (PyObject *)it;
+
+ err:
+        PyErr_SetString(PyExc_ValueError, "array is not broadcastable to "\
+                        "correct shape");
+        return NULL;
 }
 
 
+
+
+
 /*OBJECT_API
  Get Iterator that iterates over all but one axis (don't use this with
  PyArray_ITER_GOTO1D).  The axis will be over-written if negative.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2006-08-26 00:49:30 UTC (rev 3073)
+++ trunk/numpy/lib/function_base.py	2006-08-26 07:50:53 UTC (rev 3074)
@@ -1106,7 +1106,22 @@
     """Return a new array with values inserted along the given axis
     before the given indices
 
-    If axis is None, then ravel the array first. 
+    If axis is None, then ravel the array first.
+
+    The obj argument can be an integer, a slice, or a sequence of
+    integers.
+
+    Example:
+    >>> a = array([[1,2,3],
+    ...            [4,5,6],
+    ...            [7,8,9]])
+
+    >>> insertinto(a, [1,2], [[4],[5]], axis=0)
+    array([[1, 2, 3],
+           [4, 4, 4],
+           [4, 5, 6],
+           [5, 5, 5],
+           [7, 8, 9]])
     """
     arr = asarray(arr)
     ndim = arr.ndim    
@@ -1139,28 +1154,32 @@
     elif isinstance(obj, slice):
         # turn it into a range object
         obj = arange(*obj.indices(N),**{'dtype':intp})
-    
-    # default behavior
-    # FIXME: this is too slow
-    obj = array(obj, dtype=intp, copy=0, ndmin=1)
-    try:
-        if len(values) != len(obj):
-            raise TypeError
-    except TypeError:
-        values = [values]*len(obj)
-    new = arr
-    k = 0
-    for item, val in zip(obj, values):
-        new = insertinto(new, item+k, val, axis=axis)
+
+    # get two sets of indices
+    #  one is the indices which will hold the new stuff
+    #  two is the indices where arr will be copied over
+
+    obj = asarray(obj, dtype=intp)
+    numnew = len(obj)
+    index1 = obj + arange(numnew)
+    index2 = setdiff1d(arange(numnew+N),index1)
+    newshape[axis] += numnew
+    new = empty(newshape, arr.dtype, arr.flags.fnc)
+    slobj2 = [slice(None)]*ndim
+    slobj[axis] = index1
+    slobj2[axis] = index2
+    new[slobj] = values
+    new[slobj2] = arr
+        
     return new
 
-def appendonto(arr, obj, axis=None):
+def appendonto(arr, values, axis=None):
     """Append to the end of an array along axis (ravel first if None)
     """
     arr = asarray(arr)
     if axis is None:
         if arr.ndim != 1:
             arr = arr.ravel()
-        obj = ravel(obj)
+        values = ravel(values)
         axis = 0
-    return concatenate((arr, obj), axis=axis)
+    return concatenate((arr, values), axis=axis)



More information about the Numpy-svn mailing list