[Numpy-svn] r3016 - in trunk/numpy/core: . src

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 14 15:01:25 CDT 2006


Author: oliphant
Date: 2006-08-14 15:01:12 -0500 (Mon, 14 Aug 2006)
New Revision: 3016

Modified:
   trunk/numpy/core/defchararray.py
   trunk/numpy/core/numeric.py
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/src/ufuncobject.c
Log:
Strip characters from chararrays during comparison

Modified: trunk/numpy/core/defchararray.py
===================================================================
--- trunk/numpy/core/defchararray.py	2006-08-14 10:10:55 UTC (rev 3015)
+++ trunk/numpy/core/defchararray.py	2006-08-14 20:01:12 UTC (rev 3016)
@@ -1,5 +1,5 @@
 from numerictypes import string_, unicode_, integer, object_
-from numeric import ndarray, broadcast, empty
+from numeric import ndarray, broadcast, empty, compare_chararrays
 from numeric import array as narray
 import sys
 
@@ -12,7 +12,8 @@
 # This adds + and * operations and methods of str and unicode types
 #  which operate on an element-by-element basis
 
-# It also strips white-space on element retrieval
+# It also strips white-space on element retrieval and on
+#   comparisons
 
 class chararray(ndarray):
     def __new__(subtype, shape, itemsize=1, unicode=False, buffer=None,
@@ -44,9 +45,31 @@
     def __getitem__(self, obj):
         val = ndarray.__getitem__(self, obj)
         if isinstance(val, (string_, unicode_)):
-            return val.rstrip()
+            temp = val.rstrip()
+            if len(temp) == 0:
+                val = val[0]
+            else:
+                val = temp
         return val
 
+    def __eq__(self, other):  # elementwise '=='; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '==', True)
+
+    def __ne__(self, other):  # elementwise '!='; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '!=', True)
+
+    def __ge__(self, other):  # elementwise '>='; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '>=', True)
+
+    def __le__(self, other):  # elementwise '<='; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '<=', True)
+
+    def __gt__(self, other):  # elementwise '>'; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '>', True)
+
+    def __lt__(self, other):  # elementwise '<'; the final True is compare_chararrays' rstrip flag
+        return compare_chararrays(self, other, '<', True)
+
     def __add__(self, other):
         b = broadcast(self, other)
         arr = b.iters[1].base

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-08-14 10:10:55 UTC (rev 3015)
+++ trunk/numpy/core/numeric.py	2006-08-14 20:01:12 UTC (rev 3016)
@@ -14,7 +14,7 @@
            'fromiter', 'array_equal', 'array_equiv',
            'indices', 'fromfunction',
            'load', 'loads', 'isscalar', 'binary_repr', 'base_repr',
-           'ones', 'identity', 'allclose',
+           'ones', 'identity', 'allclose', 'compare_chararrays',
            'seterr', 'geterr', 'setbufsize', 'getbufsize',
            'seterrcall', 'geterrcall', 'flatnonzero',
            'Inf', 'inf', 'infty', 'Infinity',
@@ -119,6 +119,7 @@
 set_numeric_ops = multiarray.set_numeric_ops
 can_cast = multiarray.can_cast
 lexsort = multiarray.lexsort
+compare_chararrays = multiarray.compare_chararrays
 
 
 def asarray(a, dtype=None, order=None):

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-08-14 10:10:55 UTC (rev 3015)
+++ trunk/numpy/core/src/arrayobject.c	2006-08-14 20:01:12 UTC (rev 3016)
@@ -4183,9 +4183,130 @@
         return 0;
 }
 
+/* Borrowed from Numarray */
+
+#define SMALL_STRING 2048
+
+#if defined(isspace)
+#undef isspace
+#define isspace(c)  ((c==' ')||(c=='\t')||(c=='\n')||(c=='\r')||(c=='\v')||(c=='\f'))
+#endif
+
+static void _rstripw(char *s, int n)  /* in place: NUL out whitespace at the tail of s (at most n bytes read) */
+{
+        int i;
+        for(i=strnlen(s,n)-1; i>=1; i--)  /* Never strip to length 0. */
+        {
+                int c = (unsigned char)s[i];  /* library isspace() is undefined for negative char values */
+                if (!c || isspace(c))
+                        s[i] = 0;
+                else
+                        break;  /* first non-blank from the right ends the scan */
+        }
+}
+
+static void _unistripw(PyArray_UCS4 *s, int n)  /* UCS4 variant: NUL out trailing blanks/NULs of s[0..n) in place */
+{
+        int i;
+        for(i=n-1; i>=1; i--)  /* Never strip to length 0. */
+        {
+                PyArray_UCS4 c = s[i];
+                if (!c || (c <= 0xFF && isspace((int)c)))  /* library isspace() is undefined above UCHAR_MAX */
+                        s[i] = 0;
+                else
+                        break;  /* first non-blank from the right ends the scan */
+        }
+}
+
+
+static char *  /* returns a right-stripped copy of original[0..nc), or NULL after PyErr_NoMemory() */
+_char_copy_n_strip(char *original, char *temp, int nc)
+{
+        if (nc > SMALL_STRING) {  /* temp is only used up to SMALL_STRING bytes; larger copies go on the heap */
+                temp = malloc(nc);
+                if (!temp) {
+                        PyErr_NoMemory();
+                        return NULL;
+                }
+        }
+        memcpy(temp, original, nc);
+        _rstripw(temp, nc);  /* strip the copy; original is never written */
+        return temp;  /* release with _char_release(ptr, nc), which frees only the malloc case */
+}
+
+static void  /* counterpart of _char_copy_n_strip: frees ptr only when that call had to malloc */
+_char_release(char *ptr, int nc)
+{
+        if (nc > SMALL_STRING) {  /* same threshold as the allocation test in _char_copy_n_strip */
+                free(ptr);
+        }
+}
+
+static char *  /* UCS4 twin of _char_copy_n_strip; nc counts characters, not bytes; NULL after PyErr_NoMemory() */
+_uni_copy_n_strip(char *original, char *temp, int nc)
+{
+        if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) {  /* byte-size test, identical to the one in _uni_release */
+                temp = malloc(nc*sizeof(PyArray_UCS4));  /* size in bytes: malloc(nc) was too small for the memcpy below */
+                if (!temp) {
+                        PyErr_NoMemory();
+                        return NULL;
+                }
+        }
+        memcpy(temp, original, nc*sizeof(PyArray_UCS4));
+        _unistripw((PyArray_UCS4 *)temp, nc);  /* strip the copy; original is never written */
+        return temp;  /* release with _uni_release(ptr, nc), which frees only the malloc case */
+}
+
+static void  /* counterpart of _uni_copy_n_strip; nc is a count of UCS4 characters */
+_uni_release(char *ptr, int nc)
+{
+        if (nc*sizeof(PyArray_UCS4) > SMALL_STRING) {  /* only oversized strings were heap-allocated by the copy step */
+                free(ptr);
+        }
+}
+
+
+/* End borrowed from numarray */
+
+#define _rstrip_loop(CMP) { /* per element: compare right-stripped copies; uses the locals of _compare_strings */ \
+                void *aptr, *bptr; \
+                char atemp[SMALL_STRING], btemp[SMALL_STRING]; /* scratch buffers handed to stripfunc */ \
+                while(size--) { \
+                        aptr = stripfunc(iself->dataptr, atemp, N1); \
+                        if (!aptr) return -1; /* stripfunc returned NULL (allocation failed) */ \
+                        bptr = stripfunc(iother->dataptr, btemp, N2); \
+                        if (!bptr) { \
+                                relfunc(aptr, N1); /* do not leak the first copy */ \
+                                return -1; \
+                        } \
+                        val = cmpfunc(aptr, bptr, N1, N2); \
+                        *dptr = (val CMP 0); \
+                        PyArray_ITER_NEXT(iself); \
+                        PyArray_ITER_NEXT(iother); \
+                        dptr += 1; \
+                        relfunc(aptr, N1); \
+                        relfunc(bptr, N2); \
+                } \
+        }
+
+#define _reg_loop(CMP) { /* per element: compare the iterators' data in place, no stripping */ \
+                while(size--) {                                 \
+                        val = cmpfunc((void *)iself->dataptr,   \
+                                      (void *)iother->dataptr,  \
+                                      N1, N2);                  \
+                        *dptr = (val CMP 0);                    \
+                        PyArray_ITER_NEXT(iself);               \
+                        PyArray_ITER_NEXT(iother);              \
+                        dptr += 1;                              \
+                } \
+        }
+
+#define _loop(CMP) if (rstrip) _rstrip_loop(CMP) /* both arms expand to brace blocks */ \
+        else _reg_loop(CMP)
+
 static int
 _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, 
-                 int cmp_op, void *func)
+                 int cmp_op, void *func, int rstrip)
 {
         PyArrayIterObject *iself, *iother;
         Bool *dptr;
@@ -4193,6 +4314,8 @@
         int val;
         int N1, N2;
         int (*cmpfunc)(void *, void *, int, int);
+        void (*relfunc)(char *, int);
+        char* (*stripfunc)(char *, char *, int);
         
         cmpfunc = func;
         dptr = (Bool *)PyArray_DATA(result);
@@ -4204,43 +4327,48 @@
         if ((void *)cmpfunc == (void *)_myunincmp) {
                 N1 >>= 2;
                 N2 >>= 2;
+                stripfunc = _uni_copy_n_strip;
+                relfunc = _uni_release;
         }
-        while(size--) {
-                val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, 
-                              N1, N2);
-                switch (cmp_op) {
-                case Py_EQ:
-                        *dptr = (val == 0);
-                        break;
-                case Py_NE:
-                        *dptr = (val != 0);
-                        break;
-                case Py_LT:
-                        *dptr = (val < 0);
-                        break;
-                case Py_LE:
-                        *dptr = (val <= 0);
-                        break;
-                case Py_GT:
-                        *dptr = (val > 0);
-                        break;
-                case Py_GE:
-                        *dptr = (val >= 0);
-                        break;
-                default:
-                        PyErr_SetString(PyExc_RuntimeError,
-                                        "bad comparison operator");
-                        return -1;
-                }
-                PyArray_ITER_NEXT(iself);
-                PyArray_ITER_NEXT(iother);
-                dptr += 1;
+        else {
+                stripfunc = _char_copy_n_strip;
+                relfunc = _char_release;
         }
+        switch (cmp_op) {
+        case Py_EQ:
+                _loop(==)
+                break;
+        case Py_NE:
+                _loop(!=)
+                break;                
+        case Py_LT:
+                _loop(<)
+                break;
+        case Py_LE:
+                _loop(<=)
+                break;
+        case Py_GT:
+                _loop(>)
+                break;
+        case Py_GE:
+                _loop(>=)
+                break;
+        default:
+                PyErr_SetString(PyExc_RuntimeError,
+                                "bad comparison operator");
+                return -1;
+        }
         return 0;       
 }
 
+#undef _loop
+#undef _reg_loop
+#undef _rstrip_loop
+#undef SMALL_STRING
+
 static PyObject *
-_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
+                     int rstrip)
 {
         PyObject *result;
         PyArrayMultiIterObject *mit;
@@ -4294,10 +4422,12 @@
         if (result == NULL) goto finish;
 
         if (self->descr->type_num == PyArray_UNICODE) {
-                val = _compare_strings(result, mit, cmp_op, _myunincmp);
+                val = _compare_strings(result, mit, cmp_op, _myunincmp, 
+                                       rstrip);
         }
         else {
-                val = _compare_strings(result, mit, cmp_op, _mystrncmp);
+                val = _compare_strings(result, mit, cmp_op, _mystrncmp, 
+                                       rstrip);
         }
         
         if (val < 0) {Py_DECREF(result); result = NULL;}
@@ -4361,7 +4491,7 @@
         }
         else { /* compare as a string */
                 /* assumes self and other have same descr->type */
-                return _strings_richcompare(self, other, cmp_op);
+                return _strings_richcompare(self, other, cmp_op, 0);
         }
 }
 
@@ -4531,7 +4661,7 @@
                 if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) {
                         Py_DECREF(result);
                         result = _strings_richcompare(self, (PyArrayObject *)
-                                                      array_other, cmp_op);
+                                                      array_other, cmp_op, 0);
                 }
                 Py_DECREF(array_other);
         }

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2006-08-14 10:10:55 UTC (rev 3015)
+++ trunk/numpy/core/src/multiarraymodule.c	2006-08-14 20:01:12 UTC (rev 3016)
@@ -6357,6 +6357,71 @@
     return PyString_FromString(repr);
 }
 
+static PyObject *  /* compare_chararrays(a1, a2, cmp, rstrip): elementwise comparison of two string arrays */
+compare_chararrays(PyObject *dummy, PyObject *args, PyObject *kwds)
+{
+    PyObject *array;
+    PyObject *other;
+    PyArrayObject *newarr, *newoth;
+    int cmp_op;
+    Bool rstrip;
+    char *cmp_str;
+    Py_ssize_t strlen;  /* length of cmp_str; NOTE(review): "s#" fills a Py_ssize_t only under PY_SSIZE_T_CLEAN - confirm */
+    PyObject *res=NULL;
+    static char msg[] = \
+            "comparison must be '==', '!=', '<', '>', '<=', '>='";
+
+    static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL};
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist,
+                                     &array, &other, 
+                                     &cmp_str, &strlen,
+                                     PyArray_BoolConverter, &rstrip))
+            return NULL;
+
+    if (strlen < 1 || strlen > 2) goto err;
+    if (strlen > 1) {  /* two characters: the second must be '=' */
+            if (cmp_str[1] != '=') goto err;
+            if (cmp_str[0] == '=') cmp_op = Py_EQ;
+            else if (cmp_str[0] == '!') cmp_op = Py_NE;
+            else if (cmp_str[0] == '<') cmp_op = Py_LE;
+            else if (cmp_str[0] == '>') cmp_op = Py_GE;
+            else goto err;
+    }
+    else {  /* one character: only '<' and '>' are accepted */
+            if (cmp_str[0] == '<') cmp_op = Py_LT;
+            else if (cmp_str[0] == '>') cmp_op = Py_GT;
+            else goto err;
+    }
+    
+    newarr = (PyArrayObject *)PyArray_FROM_O(array);
+    if (newarr == NULL) return NULL;
+    newoth = (PyArrayObject *)PyArray_FROM_O(other);
+    if (newoth == NULL) {
+            Py_DECREF(newarr);
+            return NULL;
+    }    
+    
+    if (PyArray_ISSTRING(newarr) && PyArray_ISSTRING(newoth)) {
+            res = _strings_richcompare(newarr, newoth, cmp_op, rstrip != 0);
+    }
+    else {  /* res stays NULL, so the TypeError is returned after the cleanup below */
+            PyErr_SetString(PyExc_TypeError, 
+                            "comparison of non-string arrays");
+    }
+
+    Py_DECREF(newarr);
+    Py_DECREF(newoth);
+    return res;
+
+ err:  /* only reached before newarr/newoth are created, so nothing to release */
+    PyErr_SetString(PyExc_ValueError, msg);
+    return NULL;
+}
+
+
+
+
 static struct PyMethodDef array_module_methods[] = {
 	{"_get_ndarray_c_version", (PyCFunction)array__get_ndarray_c_version, 
 	 METH_VARARGS|METH_KEYWORDS, NULL},
@@ -6409,6 +6474,8 @@
 	 METH_VARARGS | METH_KEYWORDS, NULL},
         {"format_longfloat", (PyCFunction)format_longfloat,
          METH_VARARGS | METH_KEYWORDS, NULL},
+        {"compare_chararrays", (PyCFunction)compare_chararrays,
+         METH_VARARGS | METH_KEYWORDS, NULL},
 	{NULL,		NULL, 0}		/* sentinel */
 };
 

Modified: trunk/numpy/core/src/ufuncobject.c
===================================================================
--- trunk/numpy/core/src/ufuncobject.c	2006-08-14 10:10:55 UTC (rev 3015)
+++ trunk/numpy/core/src/ufuncobject.c	2006-08-14 20:01:12 UTC (rev 3016)
@@ -1761,7 +1761,7 @@
         PyUFuncReduceObject *loop;
         PyArrayObject *idarr;
 	PyArrayObject *aar;
-        intp loop_i[MAX_DIMS], outsize;
+        intp loop_i[MAX_DIMS], outsize=0;
         int arg_types[3] = {otype, otype, otype};
 	PyArray_SCALARKIND scalars[3] = {PyArray_NOSCALAR, PyArray_NOSCALAR, 
 					 PyArray_NOSCALAR};



More information about the Numpy-svn mailing list