[Numpy-svn] r5994 - in trunk/numpy/core: . src tests

numpy-svn@scip... numpy-svn@scip...
Sun Nov 9 18:28:24 CST 2008


Author: stefan
Date: 2008-11-09 18:28:04 -0600 (Sun, 09 Nov 2008)
New Revision: 5994

Added:
   trunk/numpy/core/src/umath_tests.c.src
Modified:
   trunk/numpy/core/setup.py
   trunk/numpy/core/tests/test_ufunc.py
Log:
Add tests for generalized ufuncs.

Modified: trunk/numpy/core/setup.py
===================================================================
--- trunk/numpy/core/setup.py	2008-11-10 00:27:06 UTC (rev 5993)
+++ trunk/numpy/core/setup.py	2008-11-10 00:28:04 UTC (rev 5994)
@@ -399,7 +399,15 @@
                          extra_info = blas_info
                          )
 
+    config.add_extension('umath_tests',
+                         sources = [join('src','umath_tests.c.src'),
+                                    ],
+                         depends = [join('blasdot','cblas.h'),] + deps,
+                         include_dirs = ['blasdot'],
+                         extra_info = blas_info
+                         )
 
+
     config.add_data_dir('tests')
     config.add_data_dir('tests/data')
 

Added: trunk/numpy/core/src/umath_tests.c.src
===================================================================
--- trunk/numpy/core/src/umath_tests.c.src	2008-11-10 00:27:06 UTC (rev 5993)
+++ trunk/numpy/core/src/umath_tests.c.src	2008-11-10 00:28:04 UTC (rev 5994)
@@ -0,0 +1,417 @@
+/* -*- c -*- */
+
+/*
+ *****************************************************************************
+ **                            INCLUDES                                     **
+ *****************************************************************************
+ */
+#include "Python.h"
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
+
+#include <limits.h>
+
+#ifndef CBLAS_HEADER
+#define CBLAS_HEADER "cblas.h"
+#endif
+#include CBLAS_HEADER
+
+/*
+ *****************************************************************************
+ **                            BASICS                                       **
+ *****************************************************************************
+ */
+
+/* Short local alias for npy_intp, used for all sizes and byte strides below. */
+typedef npy_intp intp;
+
+/*
+ * Outer-loop setup shared by the generalized-ufunc loops in this file.
+ *
+ * INIT_OUTER_LOOP_k expands to declarations of:
+ *   dN          - the outer (broadcast) loop length, read from dimensions[0]
+ *   N_          - the outer loop counter (uninitialized here)
+ *   s0 .. s(k-1) - the outer byte stride of each of the k arguments
+ * Reading them post-increments `dimensions` and `steps`, so afterwards
+ * dimensions[0] / steps[0] refer to the first entry that follows the outer
+ * ones (the core dimensions and core strides).  Because the expansion is a
+ * list of declarations, place it where declarations are allowed.
+ */
+#define INIT_OUTER_LOOP_1       \
+    intp dN = *dimensions++;    \
+    intp N_;                    \
+    intp s0 = *steps++;
+
+#define INIT_OUTER_LOOP_2       \
+    INIT_OUTER_LOOP_1           \
+    intp s1 = *steps++;
+
+#define INIT_OUTER_LOOP_3       \
+    INIT_OUTER_LOOP_2           \
+    intp s2 = *steps++;
+
+#define INIT_OUTER_LOOP_4       \
+    INIT_OUTER_LOOP_3           \
+    intp s3 = *steps++;
+
+/*
+ * BEGIN_OUTER_LOOP_k opens a for-loop running N_ over [0, dN); after each
+ * iteration args[0] .. args[k-1] are advanced by s0 .. s(k-1).  Note that this
+ * updates the caller's args array in place.  The loop brace is left open and
+ * must be closed with END_OUTER_LOOP.
+ */
+#define BEGIN_OUTER_LOOP_3      \
+    for (N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2) {
+
+#define BEGIN_OUTER_LOOP_4      \
+    for (N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2, args[3] += s3) {
+
+/* Closes the loop opened by BEGIN_OUTER_LOOP_k. */
+#define END_OUTER_LOOP  }
+
+
+/*
+ *****************************************************************************
+ **                             UFUNC LOOPS                                 **
+ *****************************************************************************
+ */
+
+char *inner1d_signature = "(i),(i)->()";
+
+/**begin repeat
+
+   #TYPE=LONG,DOUBLE#
+   #typ=npy_long, npy_double#
+*/
+
+/*
+ * Inner product over the last axis, signature "(i),(i)->()":
+ *
+ *     out[n] = sum over i of in1[n, i] * in2[n, i]
+ *
+ * dimensions[0] is the outer length and dimensions[1] the core length i.
+ * steps[0..2] are the outer byte strides of in1, in2 and out, and
+ * steps[3..4] the core byte strides of in1 and in2.  args[0..2] are
+ * advanced by their outer strides as the loop runs.
+ */
+static void
+@TYPE@_inner1d(char **args, intp *dimensions, intp *steps, void *func)
+{
+    intp n_outer = dimensions[0];
+    intp n_core = dimensions[1];
+    intp step_a = steps[0], step_b = steps[1], step_out = steps[2];
+    intp core_a = steps[3], core_b = steps[4];
+    intp k, j;
+
+    for (k = 0; k < n_outer; k++) {
+        char *pa = args[0], *pb = args[1];
+        @typ@ acc = 0;
+        for (j = 0; j < n_core; j++) {
+            acc += (*(@typ@ *)(pa + j*core_a)) * (*(@typ@ *)(pb + j*core_b));
+        }
+        *(@typ@ *)args[2] = acc;
+        args[0] += step_a;
+        args[1] += step_b;
+        args[2] += step_out;
+    }
+}
+
+/**end repeat**/
+
+char *innerwt_signature = "(i),(i),(i)->()";
+
+/**begin repeat
+
+   #TYPE=LONG,DOUBLE#
+   #typ=npy_long, npy_double#
+*/
+
+/*
+ * Weighted inner product over the last axis, signature "(i),(i),(i)->()":
+ *
+ *     out[n] = sum over i of in1[n, i] * in2[n, i] * in3[n, i]
+ *
+ * dimensions[0] is the outer length and dimensions[1] the core length i.
+ * steps[0..3] are the outer byte strides of in1, in2, in3 and out, and
+ * steps[4..6] the core byte strides of in1, in2 and in3.  args[0..3] are
+ * advanced by their outer strides as the loop runs.
+ */
+static void
+@TYPE@_innerwt(char **args, intp *dimensions, intp *steps, void *func)
+{
+    intp n_outer = dimensions[0];
+    intp n_core = dimensions[1];
+    intp step_a = steps[0], step_b = steps[1], step_w = steps[2];
+    intp step_out = steps[3];
+    intp core_a = steps[4], core_b = steps[5], core_w = steps[6];
+    intp k, j;
+
+    for (k = 0; k < n_outer; k++) {
+        char *pa = args[0], *pb = args[1], *pw = args[2];
+        @typ@ acc = 0;
+        for (j = 0; j < n_core; j++) {
+            acc += (*(@typ@ *)(pa + j*core_a)) * (*(@typ@ *)(pb + j*core_b))
+                   * (*(@typ@ *)(pw + j*core_w));
+        }
+        *(@typ@ *)args[3] = acc;
+        args[0] += step_a;
+        args[1] += step_b;
+        args[2] += step_w;
+        args[3] += step_out;
+    }
+}
+
+/**end repeat**/
+
+char *matrix_multiply_signature = "(m,n),(n,p)->(m,p)";
+
+/**begin repeat
+
+   #TYPE=LONG#
+   #typ=npy_long#
+*/
+
+/*
+ *  This implements the function
+ *        out[k, m, p] = sum_n { in1[k, m, n] * in2[k, n, p] }
+ *  with a plain triple loop (no BLAS is used for this type).
+ *
+ *  After INIT_OUTER_LOOP_3, dimensions[0..2] hold the core sizes m, n, p and
+ *  steps[0..5] the byte strides of in1 (m, n), in2 (n, p) and out (m, p).
+ *  Each output row is zeroed before accumulating, so an empty inner
+ *  dimension (n == 0) produces zeros instead of leaving `out` unwritten.
+ */
+static void
+@TYPE@_matrix_multiply(char **args, intp *dimensions, intp *steps, void *func)
+{
+    INIT_OUTER_LOOP_3
+    intp dm = dimensions[0];
+    intp dn = dimensions[1];
+    intp dp = dimensions[2];
+    intp m, n, p;
+    intp is1_m = steps[0], is1_n = steps[1], is2_n = steps[2], is2_p = steps[3],
+         os_m = steps[4], os_p = steps[5];
+    BEGIN_OUTER_LOOP_3
+        char *ip1 = args[0], *ip2 = args[1], *op = args[2];
+        for (m = 0; m < dm; m++) {
+            char *row1 = ip1 + m*is1_m;
+            char *rowo = op + m*os_m;
+            for (p = 0; p < dp; p++) {
+                *(@typ@ *)(rowo + p*os_p) = 0;
+            }
+            /* p innermost: walks rows of in2 and out, as the original did */
+            for (n = 0; n < dn; n++) {
+                @typ@ val1 = *(@typ@ *)(row1 + n*is1_n);
+                char *row2 = ip2 + n*is2_n;
+                for (p = 0; p < dp; p++) {
+                    *(@typ@ *)(rowo + p*os_p) += val1 * (*(@typ@ *)(row2 + p*is2_p));
+                }
+            }
+        }
+    END_OUTER_LOOP
+}
+
+/**end repeat**/
+
+/**begin repeat
+
+   #TYPE=FLOAT,DOUBLE#
+   #B_TYPE=s, d#
+   #typ=npy_float, npy_double#
+*/
+
+/*
+ * Nonzero when `ld` (in elements) is usable as a BLAS leading dimension for
+ * a stored row of `ncols` elements: BLAS requires ld >= max(1, ncols), and
+ * the value must fit in the int argument CBLAS takes.  This also rejects
+ * zero and negative strides.
+ */
+static int
+@TYPE@_blas_ld_ok(intp ld, intp ncols)
+{
+    return ld >= (ncols > 1 ? ncols : 1) && ld <= INT_MAX;
+}
+
+/*
+ *  This implements the function
+ *        out[k, m, p] = sum_n { in1[k, m, n] * in2[k, n, p] }.
+ *
+ *  After INIT_OUTER_LOOP_3, dimensions[0..2] hold the core sizes m, n, p and
+ *  steps[0..5] the byte strides of in1 (m, n), in2 (n, p) and out (m, p).
+ *  If every operand has one unit-stride core axis and a valid leading
+ *  dimension, each outer iteration is a single cblas_?gemm call.  Otherwise
+ *  (zero/negative strides, no unit-stride axis, sizes too large for int) it
+ *  falls back to a plain triple loop.
+ */
+static void
+@TYPE@_matrix_multiply(char **args, intp *dimensions, intp *steps, void *func)
+{
+    INIT_OUTER_LOOP_3
+    intp dm = dimensions[0];
+    intp dn = dimensions[1];
+    intp dp = dimensions[2];
+    intp m, n, p;
+    intp is1_m = steps[0], is1_n = steps[1], is2_n = steps[2], is2_p = steps[3],
+         os_m = steps[4], os_p = steps[5];
+
+    enum CBLAS_ORDER Order = CblasRowMajor;
+    enum CBLAS_TRANSPOSE Trans1, Trans2;
+    int M, N, L;
+    int lda, ldb, ldc;
+    int typeSize = sizeof(@typ@);
+
+    /* CBLAS takes int sizes; fall back rather than truncate them. */
+    if (dm > INT_MAX || dn > INT_MAX || dp > INT_MAX) {
+        goto no_blas;
+    }
+
+    /*
+     * BLAS requires each array to have contiguous memory layout on one
+     * dimension and a positive stride for the other dimension.
+     */
+    if (is1_m <= 0 || is1_n <= 0 || is2_n <= 0 || is2_p <= 0) {
+        goto no_blas;
+    }
+
+    /* in1 is (m, n): row-major if n is unit-stride, column-major if m is. */
+    if (is1_n == typeSize && is1_m % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(is1_m / typeSize, dn)) {
+        Trans1 = CblasNoTrans;
+        lda = (int)(is1_m / typeSize);
+    }
+    else if (is1_m == typeSize && is1_n % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(is1_n / typeSize, dm)) {
+        Trans1 = CblasTrans;
+        lda = (int)(is1_n / typeSize);
+    }
+    else {
+        goto no_blas;
+    }
+
+    /* in2 is (n, p): row-major if p is unit-stride, column-major if n is. */
+    if (is2_p == typeSize && is2_n % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(is2_n / typeSize, dp)) {
+        Trans2 = CblasNoTrans;
+        ldb = (int)(is2_n / typeSize);
+    }
+    else if (is2_n == typeSize && is2_p % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(is2_p / typeSize, dn)) {
+        Trans2 = CblasTrans;
+        ldb = (int)(is2_p / typeSize);
+    }
+    else {
+        goto no_blas;
+    }
+
+    M = (int)dm;
+    N = (int)dp;
+    L = (int)dn;
+    /* The ld checks below also reject zero/negative output strides. */
+    if (os_p == typeSize && os_m % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(os_m / typeSize, dp)) {
+        ldc = (int)(os_m / typeSize);
+        BEGIN_OUTER_LOOP_3
+            cblas_@B_TYPE@gemm(Order, Trans1, Trans2,
+                               M, N, L,
+                               1.0, (@typ@*)args[0], lda,
+                               (@typ@*)args[1], ldb,
+                               0.0, (@typ@*)args[2], ldc);
+        END_OUTER_LOOP
+        return;
+    }
+    else if (os_m == typeSize && os_p % typeSize == 0 &&
+            @TYPE@_blas_ld_ok(os_p / typeSize, dm)) {
+        /* out is column-major: compute out^T = in2^T * in1^T instead */
+        enum CBLAS_TRANSPOSE Trans1r, Trans2r;
+        ldc = (int)(os_p / typeSize);
+        Trans1r = (Trans1 == CblasTrans) ? CblasNoTrans : CblasTrans;
+        Trans2r = (Trans2 == CblasTrans) ? CblasNoTrans : CblasTrans;
+        BEGIN_OUTER_LOOP_3
+            cblas_@B_TYPE@gemm(Order, Trans2r, Trans1r,
+                               N, M, L,
+                               1.0, (@typ@*)args[1], ldb,
+                               (@typ@*)args[0], lda,
+                               0.0, (@typ@*)args[2], ldc);
+        END_OUTER_LOOP
+        return;
+    }
+
+no_blas:
+    BEGIN_OUTER_LOOP_3
+        char *ip1 = args[0], *ip2 = args[1], *op = args[2];
+        for (m = 0; m < dm; m++) {
+            char *row1 = ip1 + m*is1_m;
+            char *rowo = op + m*os_m;
+            /* zero first so that an empty inner dimension (n == 0) yields 0 */
+            for (p = 0; p < dp; p++) {
+                *(@typ@ *)(rowo + p*os_p) = 0;
+            }
+            for (n = 0; n < dn; n++) {
+                @typ@ val1 = *(@typ@ *)(row1 + n*is1_n);
+                char *row2 = ip2 + n*is2_n;
+                for (p = 0; p < dp; p++) {
+                    *(@typ@ *)(rowo + p*os_p) += val1 * (*(@typ@ *)(row2 + p*is2_p));
+                }
+            }
+        }
+    END_OUTER_LOOP
+}
+
+/**end repeat**/
+
+/*  The inner1d and innerwt tables and registration code below were
+    generated using a slightly modified version of
+    code_generators/generate_umath.py after adding these entries to
+    defdict (note: defdict has no matrix_multiply entry):
+
+defdict = {
+'inner1d' :
+    Ufunc(2, 1, None_,
+        r'''inner on the last dimension and broadcast on the rest \n"
+        "     \"(i),(i)->()\" \n''',
+          TD('ld'),
+          ),
+'innerwt' :
+    Ufunc(3, 1, None_,
+          r'''inner1d with a weight argument \n"
+          "     \"(i),(i),(i)->()\" \n''',
+          TD('ld'),
+          ),
+}
+
+*/
+
+/* inner1d: one inner loop per supported type; type rows are (in1, in2, out). */
+static PyUFuncGenericFunction inner1d_functions[] = {
+    LONG_inner1d,
+    DOUBLE_inner1d
+};
+static void *inner1d_data[] = { NULL, NULL };
+static char inner1d_signatures[] = {
+    PyArray_LONG,   PyArray_LONG,   PyArray_LONG,
+    PyArray_DOUBLE, PyArray_DOUBLE, PyArray_DOUBLE
+};
+
+/* innerwt: type rows are (in1, in2, weight, out). */
+static PyUFuncGenericFunction innerwt_functions[] = {
+    LONG_innerwt,
+    DOUBLE_innerwt
+};
+static void *innerwt_data[] = { NULL, NULL };
+static char innerwt_signatures[] = {
+    PyArray_LONG,   PyArray_LONG,   PyArray_LONG,   PyArray_LONG,
+    PyArray_DOUBLE, PyArray_DOUBLE, PyArray_DOUBLE, PyArray_DOUBLE
+};
+
+/* matrix_multiply: type rows are (in1, in2, out). */
+static PyUFuncGenericFunction matrix_multiply_functions[] = {
+    LONG_matrix_multiply,
+    FLOAT_matrix_multiply,
+    DOUBLE_matrix_multiply
+};
+static void *matrix_multiply_data[] = { NULL, NULL, NULL };
+static char matrix_multiply_signatures[] = {
+    PyArray_LONG,   PyArray_LONG,   PyArray_LONG,
+    PyArray_FLOAT,  PyArray_FLOAT,  PyArray_FLOAT,
+    PyArray_DOUBLE, PyArray_DOUBLE, PyArray_DOUBLE
+};
+
+/*
+ * Store `f` in `dictionary` under `name` and release the local reference.
+ * Returns 0 on success.  Returns -1 with a Python exception set on failure,
+ * including when f is NULL (i.e. creating the ufunc already failed).
+ */
+static int
+add_ufunc_to_dict(PyObject *dictionary, const char *name, PyObject *f)
+{
+    int ret;
+
+    if (f == NULL) {
+        return -1;
+    }
+    ret = PyDict_SetItemString(dictionary, name, f);
+    Py_DECREF(f);
+    return ret;
+}
+
+/*
+ * Create the inner1d, innerwt and matrix_multiply generalized ufuncs and
+ * store them in `dictionary`.  On failure it returns early with a Python
+ * exception set; the caller is expected to check PyErr_Occurred().
+ */
+static void
+addUfuncs(PyObject *dictionary) {
+    PyObject *f;
+
+    f = PyUFunc_FromFuncAndDataAndSignature(inner1d_functions, inner1d_data,
+                                    inner1d_signatures, 2,
+                                    2, 1, PyUFunc_None, "inner1d",
+                                    "inner on the last dimension and broadcast on the rest \n"
+                                    "     \"(i),(i)->()\" \n",
+                                    0, inner1d_signature);
+    if (add_ufunc_to_dict(dictionary, "inner1d", f) < 0) {
+        return;
+    }
+    f = PyUFunc_FromFuncAndDataAndSignature(innerwt_functions, innerwt_data,
+                                    innerwt_signatures, 2,
+                                    3, 1, PyUFunc_None, "innerwt",
+                                    "inner1d with a weight argument \n"
+                                    "     \"(i),(i),(i)->()\" \n",
+                                    0, innerwt_signature);
+    if (add_ufunc_to_dict(dictionary, "innerwt", f) < 0) {
+        return;
+    }
+    f = PyUFunc_FromFuncAndDataAndSignature(matrix_multiply_functions,
+                                    matrix_multiply_data, matrix_multiply_signatures,
+                                    3, 2, 1, PyUFunc_None, "matrix_multiply",
+                                    "matrix multiplication on last two dimensions \n"
+                                    "     \"(m,n),(n,p)->(m,p)\" \n",
+                                    0, matrix_multiply_signature);
+    if (add_ufunc_to_dict(dictionary, "matrix_multiply", f) < 0) {
+        return;
+    }
+}
+
+/*
+    End of auto-generated code.
+*/
+
+
+
+/*
+ * test_signature(nin, nout, signature) -> int
+ *
+ * Build a throw-away ufunc with the given core signature and report whether
+ * the signature enabled generalized ("core") behaviour.  Returns NULL (so
+ * Python raises the pending exception) if argument parsing or signature
+ * parsing fails.
+ */
+static PyObject *
+UMath_Tests_test_signature(PyObject *dummy, PyObject *args)
+{
+    int nin, nout;
+    char *signature;
+    PyObject *f;
+    int core_enabled;
+
+    /* "s" type-checks the signature and yields its char buffer directly */
+    if (!PyArg_ParseTuple(args, "iis", &nin, &nout, &signature)) {
+        return NULL;
+    }
+    f = PyUFunc_FromFuncAndDataAndSignature(NULL, NULL, NULL,
+        0, nin, nout, PyUFunc_None, "no name",
+        "doc:none",
+        1, signature);
+    if (f == NULL) {
+        return NULL;
+    }
+    core_enabled = ((PyUFuncObject*)f)->core_enabled;
+    /* the ufunc was only needed for inspection; release it */
+    Py_DECREF(f);
+    return Py_BuildValue("i", core_enabled);
+}
+
+/* Module-level functions of umath_tests (the ufuncs are added separately). */
+static PyMethodDef UMath_TestsMethods[] = {
+    {"test_signature",  UMath_Tests_test_signature, METH_VARARGS,
+     "Test signature parsing of ufunc. \n"
+     "Arguments: nin nout signature \n"
+     "If parsing fails, an exception is raised. Otherwise it returns 0 for a "
+     "scalar ufunc and 1 for a generalized ufunc. \n",
+     },
+    {NULL, NULL, 0, NULL}        /* Sentinel */
+};
+
+/*
+ * Module initializer: creates the umath_tests module, imports the numpy
+ * array and ufunc C APIs, sets __version__ and registers the test ufuncs.
+ * Any failure is reported as a RuntimeError.
+ */
+PyMODINIT_FUNC
+initumath_tests(void)
+{
+    PyObject *m;
+    PyObject *d;
+    PyObject *version;
+
+    m = Py_InitModule("umath_tests", UMath_TestsMethods);
+    if (m == NULL) {
+        return;
+    }
+
+    import_array();
+    import_ufunc();
+
+    d = PyModule_GetDict(m);
+
+    version = PyString_FromString("0.1");
+    if (version != NULL) {
+        PyDict_SetItemString(d, "__version__", version);
+        Py_DECREF(version);
+    }
+
+    /*
+     * Load the ufunc operators into the module's namespace, but only if
+     * nothing has failed yet: the C API should not be called with an
+     * exception already pending.
+     */
+    if (!PyErr_Occurred()) {
+        addUfuncs(d);
+    }
+
+    if (PyErr_Occurred()) {
+        PyErr_SetString(PyExc_RuntimeError,
+                        "cannot load umath_tests module.");
+    }
+}

Modified: trunk/numpy/core/tests/test_ufunc.py
===================================================================
--- trunk/numpy/core/tests/test_ufunc.py	2008-11-10 00:27:06 UTC (rev 5993)
+++ trunk/numpy/core/tests/test_ufunc.py	2008-11-10 00:28:04 UTC (rev 5994)
@@ -1,5 +1,6 @@
 import numpy as np
 from numpy.testing import *
+import numpy.core.umath_tests as umt
 
 class TestUfunc(TestCase):
     def test_reduceat_shifting_sum(self) :
@@ -230,6 +231,193 @@
         """
         pass
 
+    def test_signature(self):
+        # the arguments to test_signature are: nin, nout, core_signature
+        # pass
+        assert_equal(umt.test_signature(2,1,"(i),(i)->()"), 1)
 
+        # pass. empty core signature; treat as plain ufunc (with trivial core)
+        assert_equal(umt.test_signature(2,1,"(),()->()"), 0)
+
+        # in the following calls, a ValueError should be raised because
+        # of error in core signature
+        # error: extra parenthesis
+        msg = "core_sig: extra parenthesis"
+        try:
+            ret = umt.test_signature(2,1,"((i)),(i)->()")
+            assert_equal(ret, None, err_msg=msg)
+        except ValueError: None
+        # error: parenthesis matching
+        msg = "core_sig: parenthesis matching"
+        try:
+            ret = umt.test_signature(2,1,"(i),)i(->()")
+            assert_equal(ret, None, err_msg=msg)
+        except ValueError: None
+        # error: incomplete signature. letters outside of parenthesis are ignored
+        msg = "core_sig: incomplete signature"
+        try:
+            ret = umt.test_signature(2,1,"(i),->()")
+            assert_equal(ret, None, err_msg=msg)
+        except ValueError: None
+        # error: incomplete signature. 2 output arguments are specified
+        msg = "core_sig: incomplete signature"
+        try:
+            ret = umt.test_signature(2,2,"(i),(i)->()")
+            assert_equal(ret, None, err_msg=msg)
+        except ValueError: None
+
+        # more complicated names for variables
+        assert_equal(umt.test_signature(2,1,"(i1,i2),(J_1)->(_kAB)"),1)
+
+    def test_get_signature(self):
+        assert_equal(umt.inner1d.signature, "(i),(i)->()")
+
+    def test_inner1d(self):
+        a = np.arange(6).reshape((2,3))
+        assert_array_equal(umt.inner1d(a,a), np.sum(a*a,axis=-1))
+
+    def test_broadcast(self):
+        msg = "broadcast"
+        a = np.arange(4).reshape((2,1,2))
+        b = np.arange(4).reshape((1,2,2))
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "extend & broadcast loop dimensions"
+        b = np.arange(4).reshape((2,2))
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "broadcast in core dimensions"
+        a = np.arange(8).reshape((4,2))
+        b = np.arange(4).reshape((4,1))
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "extend & broadcast core and loop dimensions"
+        a = np.arange(8).reshape((4,2))
+        b = np.array(7)
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "broadcast should fail"
+        a = np.arange(2).reshape((2,1,1))
+        b = np.arange(3).reshape((3,1,1))
+        try:
+            ret = umt.inner1d(a,b)
+            assert_equal(ret, None, err_msg=msg)
+        except ValueError: None
+
+    def test_type_cast(self):
+        msg = "type cast"
+        a = np.arange(6, dtype='short').reshape((2,3))
+        assert_array_equal(umt.inner1d(a,a), np.sum(a*a,axis=-1), err_msg=msg)
+        msg = "type cast on one argument"
+        a = np.arange(6).reshape((2,3))
+        b = a+0.1
+        assert_array_almost_equal(umt.inner1d(a,a), np.sum(a*a,axis=-1),
+            err_msg=msg)
+
+    def test_endian(self):
+        msg = "big endian"
+        a = np.arange(6, dtype='>i4').reshape((2,3))
+        assert_array_equal(umt.inner1d(a,a), np.sum(a*a,axis=-1), err_msg=msg)
+        msg = "little endian"
+        a = np.arange(6, dtype='<i4').reshape((2,3))
+        assert_array_equal(umt.inner1d(a,a), np.sum(a*a,axis=-1), err_msg=msg)
+
+    def test_incontiguous_array(self):
+        msg = "incontiguous memory layout of array"
+        x = np.arange(64).reshape((2,2,2,2,2,2))
+        a = x[:,0,:,0,:,0]
+        b = x[:,1,:,1,:,1]
+        a[0,0,0] = -1
+        msg2 = "make sure it references to the original array"
+        assert_equal(x[0,0,0,0,0,0], -1, err_msg=msg2)
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+        x = np.arange(24).reshape(2,3,4)
+        a = x.T
+        b = x.T
+        a[0,0,0] = -1
+        assert_equal(x[0,0,0], -1, err_msg=msg2)
+        assert_array_equal(umt.inner1d(a,b), np.sum(a*b,axis=-1), err_msg=msg)
+
+    def test_output_argument(self):
+        msg = "output argument"
+        a = np.arange(12).reshape((2,3,2))
+        b = np.arange(4).reshape((2,1,2)) + 1
+        c = np.zeros((2,3),dtype='int')
+        umt.inner1d(a,b,c)
+        assert_array_equal(c, np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "output argument with type cast"
+        c = np.zeros((2,3),dtype='int16')
+        umt.inner1d(a,b,c)
+        assert_array_equal(c, np.sum(a*b,axis=-1), err_msg=msg)
+        msg = "output argument with incontiguous layout"
+        c = np.zeros((2,3,4),dtype='int16')
+        umt.inner1d(a,b,c[...,0])
+        assert_array_equal(c[...,0], np.sum(a*b,axis=-1), err_msg=msg)
+
+    def test_innerwt(self):
+        a = np.arange(6).reshape((2,3))
+        b = np.arange(10,16).reshape((2,3))
+        w = np.arange(20,26).reshape((2,3))
+        assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
+        a = np.arange(100,124).reshape((2,3,4))
+        b = np.arange(200,224).reshape((2,3,4))
+        w = np.arange(300,324).reshape((2,3,4))
+        assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
+
+    def test_matrix_multiply(self):
+        self.compare_matrix_multiply_results(np.long)
+        self.compare_matrix_multiply_results(np.double)
+
+    def compare_matrix_multiply_results(self, tp):
+        d1 = np.array(rand(2,3,4), dtype=tp)
+        d2 = np.array(rand(2,3,4), dtype=tp)
+        msg = "matrix multiply on type %s" % d1.dtype.name
+
+        def permute_n(n):
+            if n == 1:
+                return ([0],)
+            ret = ()
+            base = permute_n(n-1)
+            for perm in base:
+                for i in xrange(n):
+                    new = perm + [n-1]
+                    new[n-1] = new[i]
+                    new[i] = n-1
+                    ret += (new,)
+            return ret
+
+        def slice_n(n):
+            if n == 0:
+                return ((),)
+            ret = ()
+            base = slice_n(n-1)
+            for sl in base:
+                ret += (sl+(slice(None),),)
+                ret += (sl+(slice(0,1),),)
+            return ret
+
+        def broadcastable(s1,s2):
+            return s1 == s2 or s1 == 1 or s2 == 1
+
+        permute_3 = permute_n(3)
+        slice_3 = slice_n(3) + ((slice(None,None,-1),)*3,)
+
+        ref = True
+        for p1 in permute_3:
+            for p2 in permute_3:
+                for s1 in slice_3:
+                    for s2 in slice_3:
+                        a1 = d1.transpose(p1)[s1]
+                        a2 = d2.transpose(p2)[s2]
+                        ref = ref and a1.base != None and a1.base.base != None
+                        ref = ref and a2.base != None and a2.base.base != None
+                        if broadcastable(a1.shape[-1], a2.shape[-2]) and \
+                           broadcastable(a1.shape[0], a2.shape[0]):
+                            assert_array_almost_equal(
+                                umt.matrix_multiply(a1,a2),
+                                np.sum(a2[...,np.newaxis].swapaxes(-3,-1) *
+                                       a1[...,np.newaxis,:], axis=-1),
+                                err_msg = msg+' %s %s' % (str(a1.shape),
+                                                          str(a2.shape)))
+
+        assert_equal(ref, True, err_msg="reference check")
+
+
 if __name__ == "__main__":
     run_module_suite()



More information about the Numpy-svn mailing list