[Numpy-svn] r3406 - trunk/numpy/core/src

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Oct 27 11:56:35 CDT 2006


Author: oliphant
Date: 2006-10-27 11:56:32 -0500 (Fri, 27 Oct 2006)
New Revision: 3406

Modified:
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/src/scalartypes.inc.src
Log:
Allow subtypes of all array scalars and fix up scalar_value to accept subtypes.

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-10-27 16:10:34 UTC (rev 3405)
+++ trunk/numpy/core/src/arrayobject.c	2006-10-27 16:56:32 UTC (rev 3406)
@@ -1432,6 +1432,7 @@
         }
         else {
                 destptr = scalar_value(obj, descr);
+                if (destptr == NULL) {Py_DECREF(obj); return NULL;}
         }
         /* copyswap for OBJECT increments the reference count */
         copyswap(destptr, data, swap, base);

Modified: trunk/numpy/core/src/scalartypes.inc.src
===================================================================
--- trunk/numpy/core/src/scalartypes.inc.src	2006-10-27 16:10:34 UTC (rev 3405)
+++ trunk/numpy/core/src/scalartypes.inc.src	2006-10-27 16:56:32 UTC (rev 3406)
@@ -31,7 +31,7 @@
 static void *
 scalar_value(PyObject *scalar, PyArray_Descr *descr)
 {
-        enum NPY_TYPES type_num;
+        int type_num;
         if (descr == NULL) {
                 descr = PyArray_DescrFromScalar(scalar);
                 type_num = descr->type_num;
@@ -63,10 +63,59 @@
         case NPY_STRING: return (void *)PyString_AS_STRING(scalar);
         case NPY_UNICODE: return (void *)PyUnicode_AS_DATA(scalar);
         case NPY_VOID: return ((PyVoidScalarObject *)scalar)->obval;
-        default:
-                return NULL;
         }
+
+        /* Must be a user-defined type --- check to see which
+           scalar it inherits from. */
+        
+#define _CHK(cls) (PyObject_IsInstance(scalar, \
+                                       (PyObject *)&Py##cls##ArrType_Type))
+#define _OBJ(lt) &(((Py##lt##ScalarObject *)scalar)->obval)
+#define _IFCASE(cls) if _CHK(cls) return _OBJ(cls)
+
+        if _CHK(Number) {
+                if _CHK(Integer) {
+                        if _CHK(SignedInteger) {
+                                _IFCASE(Byte);
+                                _IFCASE(Short);
+                                _IFCASE(Int);
+                                _IFCASE(Long);
+                                _IFCASE(LongLong);
+                        }
+                        else { /* Unsigned Integer */
+                                _IFCASE(UByte);
+                                _IFCASE(UShort);
+                                _IFCASE(UInt);
+                                _IFCASE(ULong);
+                                _IFCASE(ULongLong);
+                        }
+                }
+                else { /* Inexact */
+                        if _CHK(Floating) {
+                                _IFCASE(Float);
+                                _IFCASE(Double);
+                                _IFCASE(LongDouble);
+                        }
+                        else { /*ComplexFloating */
+                                _IFCASE(CFloat);
+                                _IFCASE(CDouble);
+                                _IFCASE(CLongDouble);
+                        }
+                }
+        }
+        else if _CHK(Bool) return _OBJ(Bool);
+        else if _CHK(Flexible) {
+                if _CHK(String) return (void *)PyString_AS_STRING(scalar);
+                if _CHK(Unicode) return (void *)PyUnicode_AS_DATA(scalar);
+                if _CHK(Void) return ((PyVoidScalarObject *)scalar)->obval;
+        }
+        else _IFCASE(Object);
+        
+        PyErr_SetString(PyExc_RuntimeError, "bad scalar");
         return NULL;
+#undef _IFCASE
+#undef _OBJ
+#undef _CHK
 }
 
 /* no error checking is performed -- ctypeptr must be same type as scalar */
@@ -143,10 +192,12 @@
                          void *ctypeptr, int outtype)
 {
         PyArray_VectorUnaryFunc* castfunc;
+        void *ptr;
         castfunc = PyArray_GetCastFunc(indescr, outtype);
         if (castfunc == NULL) return -1;
-        castfunc(scalar_value(scalar, indescr),
-                 ctypeptr, 1, NULL, NULL);
+        ptr = scalar_value(scalar, indescr);
+        if (ptr == NULL) return -1;
+        castfunc(ptr, ctypeptr, 1, NULL, NULL);
         return 0;
 }
 
@@ -191,21 +242,12 @@
 
 
         memptr = scalar_value(scalar, typecode);
-        
         if (memptr == NULL) {
-                if (PyDataType_ISUSERDEF(typecode)) {
-                        /* Use setitem to set array from scalar */
-                        if (typecode->f->setitem(scalar, 
-                                                 PyArray_DATA(r), r) < 0) {
-                                Py_XDECREF(outcode);
-                                Py_DECREF(r);
-                                return NULL;
-                        }
-                }
-                PyErr_SetString(PyExc_ValueError, "invalid scalar");
+                Py_XDECREF(outcode);
+                Py_DECREF(r);
                 return NULL;
         }
-
+        
 #ifndef Py_UNICODE_WIDE
         if (typecode->type_num == PyArray_UNICODE) {
                 PyUCS2Buffer_AsUCS4((Py_UNICODE *)memptr,
@@ -827,9 +869,11 @@
         int typenum;
 
         if (PyArray_IsScalar(self, ComplexFloating)) {
+                void *ptr;
                 typecode = _realdescr_fromcomplexscalar(self, &typenum);
-                ret = PyArray_Scalar(scalar_value(self, NULL),
-                                     typecode, NULL);
+                ptr = scalar_value(self, NULL);
+                if (ptr == NULL) {Py_DECREF(typecode); return NULL;}
+                ret = PyArray_Scalar(ptr, typecode, NULL);
                 Py_DECREF(typecode);
                 return ret;
         }
@@ -851,9 +895,15 @@
         int typenum;
 
         if (PyArray_IsScalar(self, ComplexFloating)) {
+                char *ptr;
                 typecode = _realdescr_fromcomplexscalar(self, &typenum);
-                ret = PyArray_Scalar((char *)scalar_value(self, NULL)
-                                 + typecode->elsize, typecode, NULL);
+                ptr = (char *)scalar_value(self, NULL);
+                if (ptr == NULL) {
+                        Py_DECREF(typecode);
+                        return NULL;
+                }
+                ret = PyArray_Scalar(ptr + typecode->elsize,
+                                     typecode, NULL);
         }
         else if (PyArray_IsScalar(self, Object)) {
                 PyObject *obj = ((PyObjectScalarObject *)self)->obval;
@@ -1102,9 +1152,12 @@
         if (PyArray_IsScalar(ret, Generic) &&   \
             (!PyArray_IsScalar(ret, Void))) {
                 PyArray_Descr *new;
+                void *ptr;
                 if (!PyArray_ISNBO(self->descr->byteorder)) {
                         new = PyArray_DescrFromScalar(ret);
-                        byte_swap_vector(scalar_value(ret, new), 1, new->elsize);
+                        ptr = scalar_value(ret, new);
+                        if (ptr == NULL) {Py_DECREF(new); return NULL;}
+                        byte_swap_vector(ptr, 1, new->elsize);
                         Py_DECREF(new);
                 }
         }
@@ -2366,12 +2419,10 @@
 #name=bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong, ulonglong, float, double, longdouble, cfloat, cdouble, clongdouble, string, unicode, void, object#
 #NAME=Bool, Byte, Short, Int, Long, LongLong, UByte, UShort, UInt, ULong, ULongLong, Float, Double, LongDouble, CFloat, CDouble, CLongDouble, String, Unicode, Void, Object#
         */
-        Py@NAME@ArrType_Type.tp_flags = LEAFFLAGS;
+        Py@NAME@ArrType_Type.tp_flags = BASEFLAGS;
         Py@NAME@ArrType_Type.tp_new = @name@_arrtype_new;
         Py at NAME@ArrType_Type.tp_richcompare = gentype_richcompare;
         /**end repeat**/
-        /* Allow the Void type to be subclassed -- for adding new types */
-        PyVoidArrType_Type.tp_flags = BASEFLAGS;
 
         /**begin repeat
 #name=bool, byte, short, ubyte, ushort, uint, ulong, ulonglong, float, longdouble, cfloat, clongdouble, void, object#



More information about the Numpy-svn mailing list