[Numpy-svn] r5928 - in branches/ufunc_cleanup/numpy: core core/code_generators core/include/numpy core/src core/tests distutils/command

numpy-svn@scip... numpy-svn@scip...
Sun Oct 5 13:29:28 CDT 2008


Author: charris
Date: 2008-10-05 13:29:19 -0500 (Sun, 05 Oct 2008)
New Revision: 5928

Added:
   branches/ufunc_cleanup/numpy/core/src/math_c99.inc.src
Removed:
   branches/ufunc_cleanup/numpy/core/src/_isnan.c
Modified:
   branches/ufunc_cleanup/numpy/core/SConscript
   branches/ufunc_cleanup/numpy/core/code_generators/numpy_api_order.txt
   branches/ufunc_cleanup/numpy/core/include/numpy/ufuncobject.h
   branches/ufunc_cleanup/numpy/core/setup.py
   branches/ufunc_cleanup/numpy/core/src/_signbit.c
   branches/ufunc_cleanup/numpy/core/src/arrayobject.c
   branches/ufunc_cleanup/numpy/core/src/multiarraymodule.c
   branches/ufunc_cleanup/numpy/core/tests/test_multiarray.py
   branches/ufunc_cleanup/numpy/distutils/command/config.py
Log:
Merge up to r5926.

Modified: branches/ufunc_cleanup/numpy/core/SConscript
===================================================================
--- branches/ufunc_cleanup/numpy/core/SConscript	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/SConscript	2008-10-05 18:29:19 UTC (rev 5928)
@@ -1,4 +1,4 @@
-# Last Change: Tue Aug 05 12:00 PM 2008 J
+# Last Change: Fri Oct 03 04:00 PM 2008 J
 # vim:syntax=python
 import os
 import sys
@@ -136,39 +136,56 @@
 # Set value to 1 for each defined function (in math lib)
 mfuncs_defined = dict([(f, 0) for f in mfuncs])
 
-# TODO: checklib vs checkfunc ?
-def check_func(f):
-    """Check that f is available in mlib, and add the symbol appropriately.  """
-    st = config.CheckDeclaration(f, language = 'C', includes = "#include <math.h>")
-    if st:
-        st = config.CheckFunc(f, language = 'C')
-    if st:
-        mfuncs_defined[f] = 1
-    else:
-        mfuncs_defined[f] = 0
+# Check for mandatory funcs: we barf if a single one of those is not there
+mandatory_funcs = ["sin", "cos", "tan", "sinh", "cosh", "tanh", "fabs",
+"floor", "ceil", "sqrt", "log10", "log", "exp", "asin", "acos", "atan", "fmod",
+'modf', 'frexp', 'ldexp']
 
-for f in mfuncs:
-    check_func(f)
+if not config.CheckFuncsAtOnce(mandatory_funcs):
+    raise SystemError("One of the required function to build numpy is not"
+            " available (the list is %s)." % str(mandatory_funcs))
 
-if mfuncs_defined['expl'] == 1:
-    config.Define('HAVE_LONGDOUBLE_FUNCS',
-                  comment = 'Define to 1 if long double funcs are available')
-if mfuncs_defined['expf'] == 1:
-    config.Define('HAVE_FLOAT_FUNCS',
-                  comment = 'Define to 1 if long double funcs are available')
-if mfuncs_defined['asinh'] == 1:
-    config.Define('HAVE_INVERSE_HYPERBOLIC',
-                  comment = 'Define to 1 if inverse hyperbolic funcs are '\
-                            'available')
-if mfuncs_defined['atanhf'] == 1:
-    config.Define('HAVE_INVERSE_HYPERBOLIC_FLOAT',
-                  comment = 'Define to 1 if inverse hyperbolic float funcs '\
-                            'are available')
-if mfuncs_defined['atanhl'] == 1:
-    config.Define('HAVE_INVERSE_HYPERBOLIC_LONGDOUBLE',
-                  comment = 'Define to 1 if inverse hyperbolic long double '\
-                            'funcs are available')
+# Standard functions which may not be available and for which we have a
+# replacement implementation
+#
+def check_funcs(funcs):
+    # Try CheckFuncsAtOnce first; if that fails, fall back to CheckFunc for
+    # each function in turn. NOTE(review): no status is returned to callers.
+    st = config.CheckFuncsAtOnce(funcs)
+    if not st:
+        # Global check failed, check func per func
+        for f in funcs:
+            st = config.CheckFunc(f, language = 'C')
 
+# XXX: we do not test for hypot because python checks for it (HAVE_HYPOT in
+# python.h... I wish they would clean their public headers someday)
+optional_stdfuncs = ["expm1", "log1p", "acosh", "asinh", "atanh",
+                     "rint", "trunc"]
+
+check_funcs(optional_stdfuncs)
+
+# C99 functions: float and long double versions
+c99_funcs = ["sin", "cos", "tan", "sinh", "cosh", "tanh", "fabs", "floor",
+             "ceil", "rint", "trunc", "sqrt", "log10", "log", "exp",
+             "expm1", "asin", "acos", "atan", "asinh", "acosh", "atanh",
+             "hypot", "atan2", "pow", "fmod", "modf", 'frexp', 'ldexp']
+
+for prec in ['l', 'f']:
+    fns = [f + prec for f in c99_funcs]
+    check_funcs(fns)
+
+# Normally, isnan and isinf are macro (C99), but some platforms only have
+# func, or both func and macro version. Check for macro only, and define
+# replacement ones if not found.
+# Note: including Python.h is necessary because it modifies some math.h
+# definitions
+for f in ["isnan", "isinf", "signbit", "isfinite"]:
+    includes = """\
+#include <Python.h>
+#include <math.h>
+"""
+    config.CheckDeclaration(f, includes=includes)
+
 #-------------------------------------------------------
 # Define the function PyOS_ascii_strod if not available
 #-------------------------------------------------------
@@ -234,6 +251,7 @@
 # Generate generated code
 #------------------------
 scalartypes_src = env.GenerateFromTemplate(pjoin('src', 'scalartypes.inc.src'))
+math_c99_src = env.GenerateFromTemplate(pjoin('src', 'math_c99.inc.src'))
 arraytypes_src = env.GenerateFromTemplate(pjoin('src', 'arraytypes.inc.src'))
 sortmodule_src = env.GenerateFromTemplate(pjoin('src', '_sortmodule.c.src'))
 umathmodule_src = env.GenerateFromTemplate(pjoin('src', 'umathmodule.c.src'))

Modified: branches/ufunc_cleanup/numpy/core/code_generators/numpy_api_order.txt
===================================================================
--- branches/ufunc_cleanup/numpy/core/code_generators/numpy_api_order.txt	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/code_generators/numpy_api_order.txt	2008-10-05 18:29:19 UTC (rev 5928)
@@ -170,3 +170,4 @@
 PyArray_CheckAxis
 PyArray_OverflowMultiplyList
 PyArray_CompareString
+PyArray_MultiIterFromObjects

Modified: branches/ufunc_cleanup/numpy/core/include/numpy/ufuncobject.h
===================================================================
--- branches/ufunc_cleanup/numpy/core/include/numpy/ufuncobject.h	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/include/numpy/ufuncobject.h	2008-10-05 18:29:19 UTC (rev 5928)
@@ -263,11 +263,6 @@
 		| ((SW_INVALID & fpstatus) ? UFUNC_FPE_INVALID : 0);	\
 	}
 
-#define isnan(x) (_isnan((double)(x)))
-#define isinf(x) ((_fpclass((double)(x)) == _FPCLASS_PINF) ||	\
-		  (_fpclass((double)(x)) == _FPCLASS_NINF))
-#define isfinite(x) (_finite((double) x))
-
 /* Solaris --------------------------------------------------------*/
 /* --------ignoring SunOS ieee_flags approach, someone else can
 **         deal with that! */

Modified: branches/ufunc_cleanup/numpy/core/setup.py
===================================================================
--- branches/ufunc_cleanup/numpy/core/setup.py	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/setup.py	2008-10-05 18:29:19 UTC (rev 5928)
@@ -5,20 +5,6 @@
 from numpy.distutils import log
 from distutils.dep_util import newer
 
-FUNCTIONS_TO_CHECK = [
-    ('expl', 'HAVE_LONGDOUBLE_FUNCS'),
-    ('expf', 'HAVE_FLOAT_FUNCS'),
-    ('log1p', 'HAVE_LOG1P'),
-    ('expm1', 'HAVE_EXPM1'),
-    ('asinh', 'HAVE_INVERSE_HYPERBOLIC'),
-    ('atanhf', 'HAVE_INVERSE_HYPERBOLIC_FLOAT'),
-    ('atanhl', 'HAVE_INVERSE_HYPERBOLIC_LONGDOUBLE'),
-    ('isnan', 'HAVE_ISNAN'),
-    ('isinf', 'HAVE_ISINF'),
-    ('rint', 'HAVE_RINT'),
-    ('trunc', 'HAVE_TRUNC'),
-    ]
-
 def is_npy_no_signal():
     """Return True if the NPY_NO_SIGNAL symbol must be defined in configuration
     header."""
@@ -49,6 +35,75 @@
             nosmp = 0
     return nosmp == 1
 
+def check_math_capabilities(config, moredefs, mathlibs):
+    def check_func(func_name):
+        return config.check_func(func_name, libraries=mathlibs,
+                                 decl=True, call=True)
+
+    def check_funcs_once(funcs_name):
+        decl = dict([(f, True) for f in funcs_name])
+        st = config.check_funcs_once(funcs_name, libraries=mathlibs,
+                                     decl=decl, call=decl)
+        if st:
+            moredefs.extend([name_to_defsymb(f) for f in funcs_name])
+        return st
+
+    def check_funcs(funcs_name):
+        # Use check_funcs_once first, and if it does not work, test func per
+        # func. Return success only if all the functions are available
+        if not check_funcs_once(funcs_name):
+            # Global check failed, check func per func
+            for f in funcs_name:
+                if check_func(f):
+                    moredefs.append(name_to_defsymb(f))
+            return 0
+        else:
+            return 1
+
+    def name_to_defsymb(name):
+        return "HAVE_%s" % name.upper()
+
+    #use_msvc = config.check_decl("_MSC_VER")
+
+    # Mandatory functions: if not found, fail the build
+    mandatory_funcs = ["sin", "cos", "tan", "sinh", "cosh", "tanh", "fabs",
+		"floor", "ceil", "sqrt", "log10", "log", "exp", "asin",
+		"acos", "atan", "fmod", 'modf', 'frexp', 'ldexp']
+
+    if not check_funcs_once(mandatory_funcs):
+        raise SystemError("One of the required function to build numpy is not"
+                " available (the list is %s)." % str(mandatory_funcs))
+
+    # Standard functions which may not be available and for which we have a
+    # replacement implementation
+    # XXX: we do not test for hypot because python checks for it (HAVE_HYPOT in
+    # python.h... I wish they would clean their public headers someday)
+    optional_stdfuncs = ["expm1", "log1p", "acosh", "asinh", "atanh",
+                         "rint", "trunc"]
+
+    check_funcs(optional_stdfuncs)
+
+    # C99 functions: float and long double versions
+    c99_funcs = ["sin", "cos", "tan", "sinh", "cosh", "tanh", "fabs", "floor",
+                 "ceil", "rint", "trunc", "sqrt", "log10", "log", "exp",
+                 "expm1", "asin", "acos", "atan", "asinh", "acosh", "atanh",
+                 "hypot", "atan2", "pow", "fmod", "modf", 'frexp', 'ldexp']
+
+    for prec in ['l', 'f']:
+        fns = [f + prec for f in c99_funcs]
+        check_funcs(fns)
+
+    # Normally, isnan and isinf are macro (C99), but some platforms only have
+    # func, or both func and macro version. Check for macro only, and define
+    # replacement ones if not found.
+    # Note: including Python.h is necessary because it modifies some math.h
+    # definitions
+    for f in ["isnan", "isinf", "signbit", "isfinite"]:
+        st = config.check_decl(f, headers = ["Python.h", "math.h"])
+        if st:
+            moredefs.append(name_to_defsymb("decl_%s" % f))
+
+
 def configuration(parent_package='',top_path=None):
     from numpy.distutils.misc_util import Configuration,dot_join
     from numpy.distutils.system_info import get_info, default_lib_dirs
@@ -106,15 +161,8 @@
             ext.libraries.extend(mathlibs)
             moredefs.append(('MATHLIB',','.join(mathlibs)))
 
-            def check_func(func_name):
-                return config_cmd.check_func(func_name,
-                                             libraries=mathlibs, decl=False,
-                                             headers=['math.h'])
+            check_math_capabilities(config_cmd, moredefs, mathlibs)
 
-            for func_name, defsymbol in FUNCTIONS_TO_CHECK:
-                if check_func(func_name):
-                    moredefs.append(defsymbol)
-
             if is_npy_no_signal():
                 moredefs.append('__NPY_PRIVATE_NO_SIGNAL')
 
@@ -136,6 +184,17 @@
                     target_f.write('#define %s\n' % (d))
                 else:
                     target_f.write('#define %s %s\n' % (d[0],d[1]))
+
+            # Keep those for backward compatibility for now
+            target_f.write("""
+#ifdef HAVE_EXPL
+#define HAVE_LONGDOUBLE_FUNCS
+#endif
+
+#ifdef HAVE_EXPF
+#define HAVE_FLOAT_FUNCS
+#endif
+""")
             target_f.close()
             print 'File:',target
             target_f = open(target)
@@ -264,7 +323,6 @@
             join('src','scalartypes.inc.src'),
             join('src','arraytypes.inc.src'),
             join('src','_signbit.c'),
-            join('src','_isnan.c'),
             join('src','ucsnarrow.c'),
             join('include','numpy','*object.h'),
             'include/numpy/fenv/fenv.c',
@@ -298,6 +356,7 @@
                                     generate_ufunc_api,
                                     join('src','scalartypes.inc.src'),
                                     join('src','arraytypes.inc.src'),
+                                    join('src','math_c99.inc.src'),
                                     ],
                          depends = [join('src','ufuncobject.c'),
                                     generate_umath_py,

Deleted: branches/ufunc_cleanup/numpy/core/src/_isnan.c
===================================================================
--- branches/ufunc_cleanup/numpy/core/src/_isnan.c	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/src/_isnan.c	2008-10-05 18:29:19 UTC (rev 5928)
@@ -1,46 +0,0 @@
-/* Adapted from cephes */
-
-static int
-isnan(double x)
-{
-    union
-    {
-        double d;
-        unsigned short s[4];
-        unsigned int i[2];
-    } u;
-
-    u.d = x;
-
-#if SIZEOF_INT == 4
-
-#ifdef WORDS_BIGENDIAN /* defined in pyconfig.h */
-    if( ((u.i[0] & 0x7ff00000) == 0x7ff00000)
-        && (((u.i[0] & 0x000fffff) != 0) || (u.i[1] != 0)))
-        return 1;
-#else
-    if( ((u.i[1] & 0x7ff00000) == 0x7ff00000)
-        && (((u.i[1] & 0x000fffff) != 0) || (u.i[0] != 0)))
-        return 1;
-#endif
-
-#else  /* SIZEOF_INT != 4 */
-
-#ifdef WORDS_BIGENDIAN
-    if( (u.s[0] & 0x7ff0) == 0x7ff0)
-        {
-            if( ((u.s[0] & 0x000f) | u.s[1] | u.s[2] | u.s[3]) != 0 )
-                return 1;
-        }
-#else
-    if( (u.s[3] & 0x7ff0) == 0x7ff0) 
-        {
-            if( ((u.s[3] & 0x000f) | u.s[2] | u.s[1] | u.s[0]) != 0 )
-                return 1;
-        }
-#endif
-
-#endif  /* SIZEOF_INT */
-
-    return 0;
-}

Modified: branches/ufunc_cleanup/numpy/core/src/_signbit.c
===================================================================
--- branches/ufunc_cleanup/numpy/core/src/_signbit.c	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/src/_signbit.c	2008-10-05 18:29:19 UTC (rev 5928)
@@ -1,7 +1,7 @@
 /* Adapted from cephes */
 
 static int
-signbit(double x)
+signbit_d(double x)
 {
     union
     {

Modified: branches/ufunc_cleanup/numpy/core/src/arrayobject.c
===================================================================
--- branches/ufunc_cleanup/numpy/core/src/arrayobject.c	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/src/arrayobject.c	2008-10-05 18:29:19 UTC (rev 5928)
@@ -10790,6 +10790,75 @@
 /** END of Subscript Iterator **/
 
 
+/*
+  NUMPY_API
+  Get MultiIterator from an array of Python objects plus any additional objects passed as variadic arguments.
+
+  PyObject **mps -- array of PyObjects 
+  int n - number of PyObjects in the array
+  int nadd - number of additional arrays to include in the
+             iterator. 
+
+  Returns a multi-iterator object.
+ */
+static PyObject *
+PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
+{
+    va_list va;
+    PyArrayMultiIterObject *multi;
+    PyObject *current;
+    PyObject *arr;
+
+    int i, ntot, err=0;
+
+    ntot = n + nadd;
+    if (ntot < 2 || ntot > NPY_MAXARGS) {
+        PyErr_Format(PyExc_ValueError,
+                     "Need between 2 and (%d) "                 \
+                     "array objects (inclusive).", NPY_MAXARGS);
+        return NULL;
+    }
+
+    multi = _pya_malloc(sizeof(PyArrayMultiIterObject));
+    if (multi == NULL) return PyErr_NoMemory();
+    PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);
+
+    for(i=0; i<ntot; i++) multi->iters[i] = NULL;
+    multi->numiter = ntot;
+    multi->index = 0;
+
+    va_start(va, nadd);
+    for(i=0; i<ntot; i++) {
+	if (i < n) {
+	    current = mps[i];
+	}
+	else {
+	    current = va_arg(va, PyObject *);
+	}
+        arr = PyArray_FROM_O(current);
+        if (arr==NULL) {
+            err=1; break;
+        }
+        else {
+            multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
+            Py_DECREF(arr);
+        }
+    }
+
+    va_end(va);
+
+    if (!err && PyArray_Broadcast(multi) < 0) err=1;
+
+    if (err) {
+        Py_DECREF(multi);
+        return NULL;
+    }
+
+    PyArray_MultiIter_RESET(multi);
+
+    return (PyObject *)multi;  
+}
+
 /*NUMPY_API
   Get MultiIterator,
 */

Copied: branches/ufunc_cleanup/numpy/core/src/math_c99.inc.src (from rev 5926, trunk/numpy/core/src/math_c99.inc.src)

Modified: branches/ufunc_cleanup/numpy/core/src/multiarraymodule.c
===================================================================
--- branches/ufunc_cleanup/numpy/core/src/multiarraymodule.c	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/src/multiarraymodule.c	2008-10-05 18:29:19 UTC (rev 5928)
@@ -2326,50 +2326,40 @@
 PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
                NPY_CLIPMODE clipmode)
 {
-    intp *sizes, offset;
     int n, elsize;
     intp i, m;
     char *ret_data;
     PyArrayObject **mps, *ap;
-    intp *self_data, mi;
+    PyArrayMultiIterObject *multi=NULL;
+    intp mi;
     int copyret=0;
     ap = NULL;
 
     /* Convert all inputs to arrays of a common type */
+    /* Also makes them C-contiguous */
     mps = PyArray_ConvertToCommonType(op, &n);
     if (mps == NULL) return NULL;
 
-    sizes = (intp *)_pya_malloc(n*sizeof(intp));
-    if (sizes == NULL) goto fail;
-
-    ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ip,
-                                                    PyArray_INTP,
-                                                    0, 0);
-    if (ap == NULL) goto fail;
-
-    /* Check the dimensions of the arrays */
     for(i=0; i<n; i++) {
         if (mps[i] == NULL) goto fail;
-        if (ap->nd < mps[i]->nd) {
-            PyErr_SetString(PyExc_ValueError,
-                            "too many dimensions");
-            goto fail;
-        }
-        if (!PyArray_CompareLists(ap->dimensions+(ap->nd-mps[i]->nd),
-                                  mps[i]->dimensions, mps[i]->nd)) {
-            PyErr_SetString(PyExc_ValueError,
-                            "array dimensions must agree");
-            goto fail;
-        }
-        sizes[i] = PyArray_NBYTES(mps[i]);
     }
 
+    ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP);
+
+    if (ap == NULL) goto fail;
+
+    /* Broadcast all arrays to each other, index array at the end. */ 
+    multi = (PyArrayMultiIterObject *)\
+	PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap);
+    if (multi == NULL) goto fail;
+
+    /* Set-up return array */
     if (!ret) {
         Py_INCREF(mps[0]->descr);
         ret = (PyArrayObject *)PyArray_NewFromDescr(ap->ob_type,
                                                     mps[0]->descr,
-                                                    ap->nd,
-                                                    ap->dimensions,
+                                                    multi->nd,
+                                                    multi->dimensions,
                                                     NULL, NULL, 0,
                                                     (PyObject *)ap);
     }
@@ -2377,8 +2367,10 @@
         PyArrayObject *obj;
         int flags = NPY_CARRAY | NPY_UPDATEIFCOPY | NPY_FORCECAST;
 
-        if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) {
-            PyErr_SetString(PyExc_TypeError,
+        if ((PyArray_NDIM(ret) != multi->nd) || 
+	    !PyArray_CompareLists(PyArray_DIMS(ret), multi->dimensions, 
+				  multi->nd)) {
+	    PyErr_SetString(PyExc_TypeError,
                             "invalid shape for output array.");
             ret = NULL;
             goto fail;
@@ -2399,12 +2391,10 @@
 
     if (ret == NULL) goto fail;
     elsize = ret->descr->elsize;
-    m = PyArray_SIZE(ret);
-    self_data = (intp *)ap->data;
     ret_data = ret->data;
 
-    for (i=0; i<m; i++) {
-        mi = *self_data;
+    while (PyArray_MultiIter_NOTDONE(multi)) {
+	mi = *((intp *)PyArray_MultiIter_DATA(multi, n));
         if (mi < 0 || mi >= n) {
             switch(clipmode) {
             case NPY_RAISE:
@@ -2426,17 +2416,16 @@
                 break;
             }
         }
-        offset = i*elsize;
-        if (offset >= sizes[mi]) {offset = offset % sizes[mi]; }
-        memmove(ret_data, mps[mi]->data+offset, elsize);
-        ret_data += elsize; self_data++;
+        memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
+        ret_data += elsize; 
+	PyArray_MultiIter_NEXT(multi);
     }
 
     PyArray_INCREF(ret);
+    Py_DECREF(multi);
     for(i=0; i<n; i++) Py_XDECREF(mps[i]);
     Py_DECREF(ap);
     PyDataMem_FREE(mps);
-    _pya_free(sizes);
     if (copyret) {
         PyObject *obj;
         obj = ret->base;
@@ -2447,10 +2436,10 @@
     return (PyObject *)ret;
 
  fail:
+    Py_XDECREF(multi);
     for(i=0; i<n; i++) Py_XDECREF(mps[i]);
     Py_XDECREF(ap);
     PyDataMem_FREE(mps);
-    _pya_free(sizes);
     PyArray_XDECREF_ERR(ret);
     return NULL;
 }

Modified: branches/ufunc_cleanup/numpy/core/tests/test_multiarray.py
===================================================================
--- branches/ufunc_cleanup/numpy/core/tests/test_multiarray.py	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/core/tests/test_multiarray.py	2008-10-05 18:29:19 UTC (rev 5928)
@@ -946,6 +946,26 @@
         assert repr(A) == reprA
 
 
+class TestChoose(TestCase):
+    def setUp(self):
+        self.x = 2*ones((3,),dtype=int)
+        self.y = 3*ones((3,),dtype=int)
+        self.x2 = 2*ones((2,3), dtype=int)
+        self.y2 = 3*ones((2,3), dtype=int)        
+        self.ind = [0,0,1]
 
+    def test_basic(self):
+        A = np.choose(self.ind, (self.x, self.y))
+        assert_equal(A, [2,2,3])
+
+    def test_broadcast1(self):
+        A = np.choose(self.ind, (self.x2, self.y2))
+        assert_equal(A, [[2,2,3],[2,2,3]])
+    
+    def test_broadcast2(self):
+        A = np.choose(self.ind, (self.x, self.y2))
+        assert_equal(A, [[2,2,3],[2,2,3]])
+        
+
 if __name__ == "__main__":
     run_module_suite()

Modified: branches/ufunc_cleanup/numpy/distutils/command/config.py
===================================================================
--- branches/ufunc_cleanup/numpy/distutils/command/config.py	2008-10-05 18:22:35 UTC (rev 5927)
+++ branches/ufunc_cleanup/numpy/distutils/command/config.py	2008-10-05 18:29:19 UTC (rev 5928)
@@ -125,7 +125,14 @@
         self._check_compiler()
         body = []
         if decl:
-            body.append("int %s ();" % func)
+            body.append("int %s (void);" % func)
+        # Handle MSVC intrinsics: force the MS compiler to make a function call.
+        # Useful to test for some functions when built with optimization on, to
+        # avoid a build error because the intrinsic and our 'fake' test
+        # declaration do not match.
+        body.append("#ifdef _MSC_VER")
+        body.append("#pragma function(%s)" % func)
+        body.append("#endif")
         body.append("int main (void) {")
         if call:
             if call_args is None:
@@ -140,6 +147,67 @@
         return self.try_link(body, headers, include_dirs,
                              libraries, library_dirs)
 
+    def check_funcs_once(self, funcs,
+                   headers=None, include_dirs=None,
+                   libraries=None, library_dirs=None,
+                   decl=False, call=False, call_args=None):
+        """Check a list of functions at once.
+
+        This is useful to speed up things, since all the functions in the funcs
+        list will be put in one compilation unit.
+
+        Arguments
+        ---------
+
+            funcs: seq
+                list of functions to test
+            include_dirs : seq
+                list of header paths
+            libraries : seq
+                list of libraries to link the code snippet to
+            library_dirs : seq
+                list of library paths
+            decl : dict
+                for every (key, value), the declaration in the value will be
+                used for function in key. If a function is not in the
+                dictionary, no declaration will be used.
+            call : dict
+                for every item (f, value), if the value is True, a call will be
+                done to the function f"""
+        self._check_compiler()
+        body = []
+        if decl:
+            for f, v in decl.items():
+                if v:
+                    body.append("int %s (void);" % f)
+
+        # Handle MS intrinsics. See check_func for more info.
+        body.append("#ifdef _MSC_VER")
+        for func in funcs:
+            body.append("#pragma function(%s)" % func)
+        body.append("#endif")
+
+        body.append("int main (void) {")
+        if call:
+            for f in funcs:
+                if call.has_key(f) and call[f]:
+                    if not (call_args and call_args.has_key(f) and call_args[f]):
+                        args = ''
+                    else:
+                        args = call_args[f]
+                    body.append("  %s(%s);" % (f, args))
+                else:
+                    body.append("  %s;" % f)
+        else:
+            for f in funcs:
+                body.append("  %s;" % f)
+        body.append("  return 0;")
+        body.append("}")
+        body = '\n'.join(body) + "\n"
+
+        return self.try_link(body, headers, include_dirs,
+                             libraries, library_dirs)
+
     def get_output(self, body, headers=None, include_dirs=None,
                    libraries=None, library_dirs=None,
                    lang="c"):



More information about the Numpy-svn mailing list