[Numpy-svn] r8374 - in trunk: doc/release numpy/core/src/multiarray numpy/core/tests numpy/matrixlib/tests

numpy-svn@scip... numpy-svn@scip...
Fri Apr 30 02:06:02 CDT 2010


Author: ptvirtan
Date: 2010-04-30 02:06:02 -0500 (Fri, 30 Apr 2010)
New Revision: 8374

Modified:
   trunk/doc/release/2.0.0-notes.rst
   trunk/numpy/core/src/multiarray/methods.c
   trunk/numpy/core/tests/test_multiarray.py
   trunk/numpy/matrixlib/tests/test_defmatrix.py
Log:
ENH: core: add .dot() method to ndarrays; a.dot(b) == np.dot(a, b)

Modified: trunk/doc/release/2.0.0-notes.rst
===================================================================
--- trunk/doc/release/2.0.0-notes.rst	2010-04-29 23:57:39 UTC (rev 8373)
+++ trunk/doc/release/2.0.0-notes.rst	2010-04-30 07:06:02 UTC (rev 8374)
@@ -35,3 +35,16 @@
 
     >>> import warnings
     >>> warnings.simplefilter("ignore", np.ComplexWarning)
+
+Dot method for ndarrays
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Ndarrays now also provide the dot product as a method, which allows
+chains of matrix products to be written as
+
+    >>> a.dot(b).dot(c)
+
+instead of the longer alternative
+
+    >>> np.dot(np.dot(a, b), c)
+

Modified: trunk/numpy/core/src/multiarray/methods.c
===================================================================
--- trunk/numpy/core/src/multiarray/methods.c	2010-04-29 23:57:39 UTC (rev 8373)
+++ trunk/numpy/core/src/multiarray/methods.c	2010-04-30 07:06:02 UTC (rev 8374)
@@ -1793,6 +1793,29 @@
 
 
 static PyObject *
+array_dot(PyArrayObject *self, PyObject *args)
+{
+    /* a.dot(b): forwards to numpy.core.dot(a, b); returns NULL on error.
+     * Registered as METH_VARARGS, hence the two-argument signature. */
+    PyObject *b;
+    static PyObject *numpycore = NULL;
+
+    if (!PyArg_ParseTuple(args, "O:dot", &b)) {
+        return NULL;
+    }
+    /* Since blas-dot is exposed only on the Python side, we need to grab it
+     * from there; the imported module is kept in a static for later calls */
+    if (numpycore == NULL) {
+        numpycore = PyImport_ImportModule("numpy.core");
+        if (numpycore == NULL) {
+            return NULL;
+        }
+    }
+    return PyObject_CallMethod(numpycore, "dot", "OO", self, b);
+}
+
+
+static PyObject *
 array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
 {
     int axis = MAX_DIMS;
@@ -2192,6 +2215,9 @@
     {"diagonal",
         (PyCFunction)array_diagonal,
         METH_VARARGS | METH_KEYWORDS, NULL},
+    {"dot",
+        (PyCFunction)array_dot,
+        METH_VARARGS, NULL},
     {"fill",
         (PyCFunction)array_fill,
         METH_VARARGS, NULL},

Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py	2010-04-29 23:57:39 UTC (rev 8373)
+++ trunk/numpy/core/tests/test_multiarray.py	2010-04-30 07:06:02 UTC (rev 8374)
@@ -541,7 +541,14 @@
         assert_equal(x1.flatten('F'), y1f)
         assert_equal(x1.flatten('F'), x1.T.flatten())
 
+    def test_dot(self):
+        a = np.array([[1, 0], [0, 1]])
+        b = np.array([[0, 1], [1, 0]])
+        c = np.array([[9, 1], [1, -9]])
 
+        assert_equal(np.dot(a, b), a.dot(b))
+        assert_equal(np.dot(np.dot(a, b), c), a.dot(b).dot(c))
+
 class TestSubscripting(TestCase):
     def test_test_zero_rank(self):
         x = array([1,2,3])

Modified: trunk/numpy/matrixlib/tests/test_defmatrix.py
===================================================================
--- trunk/numpy/matrixlib/tests/test_defmatrix.py	2010-04-29 23:57:39 UTC (rev 8373)
+++ trunk/numpy/matrixlib/tests/test_defmatrix.py	2010-04-30 07:06:02 UTC (rev 8374)
@@ -254,7 +254,8 @@
             'compress' : ([1],),
             'repeat' : (1,),
             'reshape' : (1,),
-            'swapaxes' : (0,0)
+            'swapaxes' : (0,0),
+            'dot': np.array([1.0]),
             }
         excluded_methods = [
             'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield',
@@ -267,7 +268,7 @@
         for attrib in dir(a):
             if attrib.startswith('_') or attrib in excluded_methods:
                 continue
-            f = eval('a.%s' % attrib)
+            f = getattr(a, attrib)
             if callable(f):
                 # reset contents of a
                 a.astype('f8')



More information about the Numpy-svn mailing list