[Numpy-svn] r4822 - trunk/numpy/core/src

numpy-svn@scip... numpy-svn@scip...
Mon Feb 25 16:51:24 CST 2008


Author: oliphant
Date: 2008-02-25 16:51:21 -0600 (Mon, 25 Feb 2008)
New Revision: 4822

Modified:
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/src/scalartypes.inc.src
Log:
Allow numpy scalars to be indexed in limited ways without making them iterable.  Fix a consistency bug with [...] indexing, remove a useless check, and allow 0-d boolean arrays to work as masks for scalars.

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2008-02-24 18:01:43 UTC (rev 4821)
+++ trunk/numpy/core/src/arrayobject.c	2008-02-25 22:51:21 UTC (rev 4822)
@@ -2490,7 +2490,7 @@
     return 0;
 }
 
-int
+static int
 count_new_axes_0d(PyObject *tuple)
 {
     int i, argument_count;
@@ -2706,14 +2706,12 @@
         return NULL;
     }
 
+    if (op == Py_Ellipsis) {
+	Py_INCREF(self);
+	return (PyObject *)self;
+    }
+
     if (self->nd == 0) {
-        if (op == Py_Ellipsis) {
-            /* XXX: This leads to a small inconsistency
-               XXX: with the nd>0 case where (x[...] is x)
-               XXX: is false for nd>0 case. */
-            Py_INCREF(self);
-            return (PyObject *)self;
-        }
         if (op == Py_None)
             return add_new_axes_0d(self, 1);
         if (PyTuple_Check(op)) {
@@ -2726,9 +2724,8 @@
             return add_new_axes_0d(self, nd);
         }
         /* Allow Boolean mask selection also */
-        if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) ||
-            (PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
-             PyArray_ISBOOL(op))) {
+        if ((PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
+	     PyArray_ISBOOL(op))) {
             if (PyObject_IsTrue(op)) {
                 Py_INCREF(self);
                 return (PyObject *)self;
@@ -3051,7 +3048,8 @@
         if ((op == Py_Ellipsis) || PyString_Check(op) || PyUnicode_Check(op))
             noellipses = FALSE;
         else if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) ||
-                 (PyArray_Check(op) && (PyArray_DIMS(op)==0)))
+                 (PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
+		  PyArray_ISBOOL(op)))
             noellipses = FALSE;
         else if (PySequence_Check(op)) {
             int n, i;
@@ -9212,7 +9210,6 @@
     char *dptr;
     int size;
     PyObject *obj = NULL;
-    int swap;
     PyArray_CopySwapFunc *copyswap;
 
     if (ind == Py_Ellipsis) {

Modified: trunk/numpy/core/src/scalartypes.inc.src
===================================================================
--- trunk/numpy/core/src/scalartypes.inc.src	2008-02-24 18:01:43 UTC (rev 4821)
+++ trunk/numpy/core/src/scalartypes.inc.src	2008-02-25 22:51:21 UTC (rev 4822)
@@ -2373,6 +2373,54 @@
         0,                                        /* tp_flags */
 };
 
+
+static PyObject *
+add_new_axes_0d(PyArrayObject *arr, int nd);
+
+static int
+count_new_axes_0d(PyObject *tuple);
+
+static PyObject *
+gen_arrtype_subscript(PyObject *self, PyObject *key)
+{
+	/* Index a numpy scalar.  Only [...], [None], and tuple keys
+	   such as [..., None] are accepted; any other key raises
+	   IndexError.  The result is a new N-d array holding a copy
+	   of the scalar's data, where N is the number of None's in
+	   the key (N = 0 for [...]).  Returns NULL with an exception
+	   set on failure.
+	 */
+	PyObject *res, *ret;
+	int N;
+
+	if (key != Py_Ellipsis && key != Py_None && !PyTuple_Check(key)) {
+		PyErr_SetString(PyExc_IndexError,
+				"invalid index to scalar variable.");
+		return NULL;
+	}
+	res = PyArray_FromScalar(self, NULL);
+	if (res == NULL) return NULL;  /* conversion can fail */
+	if (key == Py_Ellipsis) {
+		return res;
+	}
+	if (key == Py_None) {
+		N = 1;
+	}
+	else {
+		/* Tuple: count_new_axes_0d returns < 0 for a bad key. */
+		N = count_new_axes_0d(key);
+		if (N < 0) {
+			/* Release the array built above before failing. */
+			Py_DECREF(res);
+			return NULL;
+		}
+	}
+	ret = add_new_axes_0d((PyArrayObject *)res, N);
+	Py_DECREF(res);
+	return ret;
+}
+
+
 /**begin repeat
 #name=bool, string, unicode, void#
 #NAME=Bool, String, Unicode, Void#
@@ -2418,6 +2466,14 @@
 #undef _THIS_SIZE
 /**end repeat**/
 
+
+static PyMappingMethods gentype_as_mapping = {  /* subscript only */
+	NULL,                                       /* mp_length: no len() */
+        (binaryfunc)gen_arrtype_subscript,          /* mp_subscript */
+        NULL                                        /* mp_ass_subscript: read-only */
+};
+
+
 /**begin repeat
 #NAME=CFloat, CDouble, CLongDouble#
 #name=complex*3#
@@ -2475,7 +2531,6 @@
 /**end repeat**/
 
 
-
 static PyNumberMethods longdoubletype_as_number;
 static PyNumberMethods clongdoubletype_as_number;
 
@@ -2486,6 +2541,7 @@
         PyGenericArrType_Type.tp_dealloc = (destructor)gentype_dealloc;
         PyGenericArrType_Type.tp_as_number = &gentype_as_number;
         PyGenericArrType_Type.tp_as_buffer = &gentype_as_buffer;
+        PyGenericArrType_Type.tp_as_mapping = &gentype_as_mapping;
         PyGenericArrType_Type.tp_flags = BASEFLAGS;
         PyGenericArrType_Type.tp_methods = gentype_methods;
         PyGenericArrType_Type.tp_getset = gentype_getsets;



More information about the Numpy-svn mailing list