[Numpy-svn] r6358 - in trunk/numpy/core: . src tests

numpy-svn@scip... numpy-svn@scip...
Wed Feb 11 22:22:07 CST 2009


Author: oliphant
Date: 2009-02-11 22:22:03 -0600 (Wed, 11 Feb 2009)
New Revision: 6358

Modified:
   trunk/numpy/core/_internal.py
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/tests/test_numerictypes.py
Log:
Add multiple-field access: indexing a structured array with a list of field names returns a new array filled with copies of just those fields.

Modified: trunk/numpy/core/_internal.py
===================================================================
--- trunk/numpy/core/_internal.py	2009-02-11 01:52:37 UTC (rev 6357)
+++ trunk/numpy/core/_internal.py	2009-02-12 04:22:03 UTC (rev 6358)
@@ -292,3 +292,22 @@
                 raise ValueError, "unknown field name: %s" % (name,)
         return tuple(list(order) + nameslist)
     raise ValueError, "unsupported order value: %s" % (order,)
+
+# Given an array with fields and a sequence of field names
+# construct a new array with just those fields copied over
+def _index_fields(ary, fields):
+    """Return a new array holding copies of only the named `fields` of `ary`."""
+    from multiarray import empty
+    dt = ary.dtype
+    # Fields keep their ary.dtype order; a set gives O(1) membership tests.
+    wanted = set(fields)
+    new_dtype = [(name, dt[name]) for name in dt.names if name in wanted]
+    if ary.flags.f_contiguous:
+        order = 'F'
+    else:
+        order = 'C'
+    newarray = empty(ary.shape, dtype=new_dtype, order=order)
+    for name in fields:
+        newarray[name] = ary[name]
+    return newarray
+    

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2009-02-11 01:52:37 UTC (rev 6357)
+++ trunk/numpy/core/src/arrayobject.c	2009-02-12 04:22:03 UTC (rev 6358)
@@ -2827,10 +2827,10 @@
     int nd, fancy;
     PyArrayObject *other;
     PyArrayMapIterObject *mit;
+    PyObject *obj;  /* shared by the field-name and multiple-field paths */
 
     if (PyString_Check(op) || PyUnicode_Check(op)) {
         if (self->descr->names) {
-            PyObject *obj;
             obj = PyDict_GetItem(self->descr->fields, op);
             if (obj != NULL) {
                 PyArray_Descr *descr;
@@ -2852,6 +2852,50 @@
         return NULL;
     }
 
+    /*
+     * Multiple-field access: a non-tuple sequence consisting solely of
+     * str/unicode field names selects those fields into a new array.
+     * Anything else falls through to the regular indexing code below.
+     */
+    if (self->descr->names && PySequence_Check(op) && !PyTuple_Check(op)) {
+	int seqlen, i;
+	seqlen = PySequence_Size(op);
+	if (seqlen < 0) {
+	    /* No usable length: not a field list; don't leak the error. */
+	    PyErr_Clear();
+	    seqlen = 0;
+	}
+	for (i = 0; i < seqlen; i++) {
+	    obj = PySequence_GetItem(op, i);
+	    if (obj == NULL) {
+		/* Unreadable item: fall back to regular indexing. */
+		PyErr_Clear();
+		break;
+	    }
+	    if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
+		Py_DECREF(obj);
+		break;
+	    }
+	    Py_DECREF(obj);
+	}
+	/*
+	 * Extract the fields only if the sequence is non-empty and every
+	 * element is str or unicode (i.e. the loop finished without break).
+	 */
+	fancy = ((seqlen > 0) && (i == seqlen));
+	if (fancy) {
+	    PyObject *_numpy_internal;
+	    _numpy_internal = PyImport_ImportModule("numpy.core._internal");
+	    if (_numpy_internal == NULL) {
+		return NULL;
+	    }
+	    obj = PyObject_CallMethod(_numpy_internal, "_index_fields",
+				      "OO", self, op);
+	    Py_DECREF(_numpy_internal);
+	    return obj;
+	}
+    }
+
     if (op == Py_Ellipsis) {
 	Py_INCREF(self);
 	return (PyObject *)self;

Modified: trunk/numpy/core/tests/test_numerictypes.py
===================================================================
--- trunk/numpy/core/tests/test_numerictypes.py	2009-02-11 01:52:37 UTC (rev 6357)
+++ trunk/numpy/core/tests/test_numerictypes.py	2009-02-12 04:22:03 UTC (rev 6358)
@@ -353,6 +353,20 @@
         res = np.find_common_type(['u8','i8','i8'],['f8'])
         assert(res == 'f8')
 
+class TestMultipleFields(TestCase):
+    """Indexing a structured array with a list of field names."""
+    def setUp(self):
+        self.ary = np.array([(1,2,3,4),(5,6,7,8)], dtype='i4,f4,i2,c8')
+    def _bad_call(self):
+        return self.ary['f0','f1']
+    def test_no_tuple(self):
+        self.assertRaises(ValueError, self._bad_call)
+    def test_return(self):
+        res = self.ary[['f0','f2']].tolist()
+        self.assertEqual(res, [(1,3), (5,7)])
+    def test_non_string_list(self):
+        # A list of ints is not a field list: regular fancy indexing applies.
+        self.assertEqual(self.ary[[1]]['f0'].tolist(), [5])
 
 if __name__ == "__main__":
     run_module_suite()



More information about the Numpy-svn mailing list