[Numpy-svn] r4805 - in trunk/numpy/core: src tests

numpy-svn@scip... numpy-svn@scip...
Thu Feb 14 01:24:23 CST 2008


Author: charris
Date: 2008-02-14 01:24:13 -0600 (Thu, 14 Feb 2008)
New Revision: 4805

Modified:
   trunk/numpy/core/src/_sortmodule.c.src
   trunk/numpy/core/tests/test_multiarray.py
Log:
Add type specific mergesort for strings and unicode.
Use an adaptive copy_string function (memcpy for strings of 16 or more characters), as suggested by Francesc.
Make the sorting code more consistent across types.


Modified: trunk/numpy/core/src/_sortmodule.c.src
===================================================================
--- trunk/numpy/core/src/_sortmodule.c.src	2008-02-14 01:32:35 UTC (rev 4804)
+++ trunk/numpy/core/src/_sortmodule.c.src	2008-02-14 07:24:13 UTC (rev 4805)
@@ -31,6 +31,7 @@
 #define PYA_QS_STACK 100
 #define SMALL_QUICKSORT 15
 #define SMALL_MERGESORT 20
+#define SMALL_STRING 16
 #define SWAP(a,b) {SWAP_temp = (b); (b)=(a); (a) = SWAP_temp;}
 #define STDC_LT(a,b) ((a) < (b))
 #define STDC_LE(a,b) ((a) <= (b))
@@ -58,26 +59,27 @@
     @type@ *pl = start;
     @type@ *pr = start + num - 1;
     @type@ vp, SWAP_temp;
-    @type@ *stack[PYA_QS_STACK], **sptr = stack, *pm, *pi, *pj, *pt;
+    @type@ *stack[PYA_QS_STACK], **sptr = stack, *pm, *pi, *pj, *pk;
 
     for(;;) {
         while ((pr - pl) > SMALL_QUICKSORT) {
             /* quicksort partition */
             pm = pl + ((pr - pl) >> 1);
-            if (@lessthan@(*pm,*pl)) SWAP(*pm,*pl);
-            if (@lessthan@(*pr,*pm)) SWAP(*pr,*pm);
-            if (@lessthan@(*pm,*pl)) SWAP(*pm,*pl);
+            if (@lessthan@(*pm, *pl)) SWAP(*pm, *pl);
+            if (@lessthan@(*pr, *pm)) SWAP(*pr, *pm);
+            if (@lessthan@(*pm, *pl)) SWAP(*pm, *pl);
             vp = *pm;
             pi = pl;
             pj = pr - 1;
-            SWAP(*pm,*pj);
+            SWAP(*pm, *pj);
             for(;;) {
-                do ++pi; while (@lessthan@(*pi,vp));
-                do --pj; while (@lessthan@(vp,*pj));
+                do ++pi; while (@lessthan@(*pi, vp));
+                do --pj; while (@lessthan@(vp, *pj));
                 if (pi >= pj)  break;
                 SWAP(*pi,*pj);
             }
-            SWAP(*pi,*(pr-1));
+            pk = pr - 1;
+            SWAP(*pi, *pk);
             /* push largest partition on stack */
             if (pi - pl < pr - pi) {
                 *sptr++ = pi + 1;
@@ -94,8 +96,10 @@
         /* insertion sort */
         for(pi = pl + 1; pi <= pr; ++pi) {
             vp = *pi;
-            for(pj = pi, pt = pi - 1; pj > pl && @lessthan@(vp, *pt);) {
-                *pj-- = *pt--;
+            pj = pi;
+            pk = pi - 1;
+            while (pj > pl && @lessthan@(vp, *pk)) {
+                *pj-- = *pk--;
             }
             *pj = vp;
         }
@@ -112,7 +116,7 @@
 {
     @type@ vp;
     intp *pl, *pr, SWAP_temp;
-    intp *stack[PYA_QS_STACK], **sptr=stack, *pm, *pi, *pj, *pt, vi;
+    intp *stack[PYA_QS_STACK], **sptr=stack, *pm, *pi, *pj, *pk, vi;
 
     pl = tosort;
     pr = tosort + num - 1;
@@ -134,7 +138,8 @@
                 if (pi >= pj)  break;
                 SWAP(*pi,*pj);
             }
-            SWAP(*pi,*(pr-1));
+            pk = pr - 1; 
+            SWAP(*pi,*pk);
             /* push largest partition on stack */
             if (pi - pl < pr - pi) {
                 *sptr++ = pi + 1;
@@ -152,8 +157,10 @@
         for(pi = pl + 1; pi <= pr; ++pi) {
             vi = *pi;
             vp = v[vi];
-            for(pj = pi, pt = pi - 1; pj > pl && @lessthan@(vp, v[*pt]);) {
-                *pj-- = *pt--;
+            pj = pi;
+            pk = pi - 1;
+            while (pj > pl && @lessthan@(vp, v[*pk])) {
+                *pj-- = *pk--;
             }
             *pj = vi;
         }
@@ -263,32 +270,35 @@
 
     if (pr - pl > SMALL_MERGESORT) {
         /* merge sort */
-        pm = pl + ((pr - pl + 1)>>1);
-        @TYPE@_mergesort0(pl,pm-1,pw);
-        @TYPE@_mergesort0(pm,pr,pw);
-        for(pi = pw, pj = pl; pj < pm; ++pi, ++pj) {
-            *pi = *pj;
+        pm = pl + ((pr - pl) >> 1);
+        @TYPE@_mergesort0(pl, pm, pw);
+        @TYPE@_mergesort0(pm, pr, pw);
+        for(pi = pw, pj = pl; pj < pm;) {
+            *pi++ = *pj++;
         }
-        for(pk = pw, pm = pl; pk < pi && pj <= pr; ++pm) {
-            if (@lessequal@(*pk,*pj)) {
-                *pm = *pk;
-                ++pk;
+        pj = pw;
+        pk = pl;
+        while (pj < pi && pm < pr) {
+            if (@lessequal@(*pj,*pm)) {
+                *pk = *pj++;
             }
             else {
-                *pm = *pj;
-                ++pj;
+                *pk = *pm++;
             }
+            pk++;
         }
-        for(; pk < pi; ++pm, ++pk) {
-            *pm = *pk;
+        while(pj < pi) {
+            *pk++ = *pj++;
         }
     }
     else {
         /* insertion sort */
-        for(pi = pl + 1; pi <= pr; ++pi) {
+        for(pi = pl + 1; pi < pr; ++pi) {
             vp = *pi;
-            for(pj = pi, pk = pi - 1; pj > pl && @lessthan@(vp, *pk); --pj, --pk) {
-                *pj = *pk;
+            pj = pi;
+            pk = pi -1;
+            while (pj > pl && @lessthan@(vp, *pk)) {
+                *pj-- = *pk--;
             }
             *pj = vp;
         }
@@ -300,17 +310,16 @@
 {
     @type@ *pl, *pr, *pw;
 
-    pl = start; pr = pl + num - 1;
-    pw = (@type@ *) PyDataMem_NEW(((1+num/2))*sizeof(@type@));
-
+    pl = start;
+    pr = pl + num;
+    pw = (@type@ *) PyDataMem_NEW((num/2)*sizeof(@type@));
     if (!pw) {
         PyErr_NoMemory();
         return -1;
     }
+    @TYPE@_mergesort0(pl, pr, pw);
 
-    @TYPE@_mergesort0(pl, pr, pw);
     PyDataMem_FREE(pw);
-
     return 0;
 }
 
@@ -398,9 +407,14 @@
 static void
 copy_string(char *s1, char *s2, size_t len)
 {
-    while(len--) {
-        *s1++ = *s2++;
+    if (len < SMALL_STRING) {
+        while(len--) {
+            *s1++ = *s2++;
+        }
     }
+    else {
+        memcpy(s1, s2, len);
+    }
 }
 
 static void
@@ -456,7 +470,81 @@
    #copy=copy_string, copy_ucs4#
 **/
 
+static void
+@TYPE@_mergesort0(@type@ *pl, @type@ *pr, @type@ *pw, @type@ *vp, size_t len)
+{
+    @type@ *pi, *pj, *pk, *pm;
+
+    if (pr - pl > SMALL_MERGESORT*len) {
+        /* merge sort */
+        pm = pl + (((pr - pl)/len) >> 1)*len;
+        @TYPE@_mergesort0(pl, pm, pw, vp, len);
+        @TYPE@_mergesort0(pm, pr, pw, vp, len);
+        @copy@(pw, pl, pm - pl);
+        pi = pw + (pm - pl);
+        pj = pw;
+        pk = pl;
+        while (pj < pi && pm < pr) {
+            if (@lessequal@(pj, pm, len)) {
+                @copy@(pk, pj, len);
+                pj += len;
+            }
+            else {
+                @copy@(pk, pm, len);
+                pm += len;
+            }
+            pk += len;
+        }
+        @copy@(pk, pj, pi - pj);
+    }
+    else {
+        /* insertion sort */
+        for(pi = pl + len; pi < pr; pi += len) {
+            @copy@(vp, pi, len);
+            pj = pi;
+            pk = pi - len;
+            while (pj > pl && @lessthan@(vp, pk, len)) {
+                @copy@(pj, pk, len);
+                pj -= len;
+                pk -= len;
+            }
+            @copy@(pj, vp, len);
+        }
+    }
+}
+
 static int
+@TYPE@_mergesort(@type@ *start, intp num, PyArrayObject *arr)
+{
+    const size_t elsize = arr->descr->elsize;
+    const size_t len = elsize / sizeof(@type@);
+    @type@ *pl, *pr, *pw, *vp;
+    int err = 0;
+
+    pl = start;
+    pr = pl + num*len;
+    pw = (@type@ *) PyDataMem_NEW((num/2)*elsize);
+    if (!pw) {
+        PyErr_NoMemory();
+        err = -1;
+        goto fail_0;
+    }
+    vp = (@type@ *) PyDataMem_NEW(elsize);
+    if (!vp) {
+        PyErr_NoMemory();
+        err = -1;
+        goto fail_1;
+    }
+    @TYPE@_mergesort0(pl, pr, pw, vp, len);
+
+    PyDataMem_FREE(vp);
+fail_1:
+    PyDataMem_FREE(pw);
+fail_0:
+    return err;
+}
+
+static int
 @TYPE@_quicksort(@type@ *start, intp num, PyArrayObject *arr)
 {
     const size_t len = arr->descr->elsize/sizeof(@type@);
@@ -466,7 +554,7 @@
     @type@ *stack[PYA_QS_STACK], **sptr = stack, *pm, *pi, *pj, *pk;
 
     for(;;) {
-        while ((pr - pl) > 5*len) {
+        while ((pr - pl) > SMALL_QUICKSORT*len) {
             /* quicksort partition */
             pm = pl + (((pr - pl)/len) >> 1)*len;
             if (@lessthan@(pm, pl, len)) @swap@(pm, pl, len);
@@ -502,7 +590,7 @@
             @copy@(vp, pi, len);
             pj = pi;
             pk = pi - len;
-            while(pj > pl && @lessthan@(vp, pk, len)) {
+            while (pj > pl && @lessthan@(vp, pk, len)) {
                 @copy@(pj, pk, len);
                 pj -= len;
                 pk -= len;
@@ -641,7 +729,8 @@
                 if (pi >= pj)  break;
                 SWAP(*pi,*pj);
             }
-            SWAP(*pi,*(pr-1));
+            pk = pr - 1;
+            SWAP(*pi,*pk);
             /* push largest partition on stack */
             if (pi - pl < pr - pi) {
                 *sptr++ = pi + 1;
@@ -660,8 +749,8 @@
             vi = *pi;
             vp = v + vi*len;
             pj = pi;
-            pk = pi -1;
-            while(pj > pl && @lessthan@(vp, v + (*pk)*len, len)) {
+            pk = pi - 1;
+            while (pj > pl && @lessthan@(vp, v + (*pk)*len, len)) {
                 *pj-- = *pk--;
             }
             *pj = vi;
@@ -683,30 +772,33 @@
 
     if (pr - pl > SMALL_MERGESORT) {
         /* merge sort */
-        pm = pl + ((pr - pl + 1)>>1);
-        @TYPE@_amergesort0(pl,pm-1,v,pw,len);
+        pm = pl + ((pr - pl) >> 1);
+        @TYPE@_amergesort0(pl,pm,v,pw,len);
         @TYPE@_amergesort0(pm,pr,v,pw,len);
         for(pi = pw, pj = pl; pj < pm;) {
             *pi++ = *pj++;
         }
-        for(pk = pw, pm = pl; pk < pi && pj <= pr;) {
-            if (@lessequal@(v + (*pk)*len, v + (*pj)*len, len)) {
-                *pm++ = *pk++;
+        pj = pw;
+        pk = pl;
+        while (pj < pi && pm < pr) {
+            if (@lessequal@(v + (*pj)*len, v + (*pm)*len, len)) {
+                *pk = *pj++;
             } else {
-                *pm++ = *pj++;
+                *pk = *pm++;
             }
+            pk++;
         }
-        while(pk < pi) {
-            *pm++ = *pk++;
+        while (pj < pi) {
+            *pk++ = *pj++;
         }
     } else {
         /* insertion sort */
-        for(pi = pl + 1; pi <= pr; ++pi) {
+        for(pi = pl + 1; pi < pr; ++pi) {
             vi = *pi;
             vp = v + vi*len;
             pj = pi;
             pk = pi -1;
-            while(pj > pl && @lessthan@(vp, v + (*pk)*len, len)) {
+            while (pj > pl && @lessthan@(vp, v + (*pk)*len, len)) {
                 *pj-- = *pk--;
             }
             *pj = vi;
@@ -718,24 +810,20 @@
 static int
 @TYPE@_amergesort(@type@ *v, intp *tosort, intp num, PyArrayObject *arr)
 {
+    const size_t elsize = arr->descr->elsize;
+    const size_t len = elsize / sizeof(@type@);
     intp *pl, *pr, *pw;
-    int elsize, chars;
 
-    elsize = arr->descr->elsize;
-
-    chars = elsize / sizeof(@type@);
-
-    pl = tosort; pr = pl + num - 1;
-    pw = PyDimMem_NEW((1+num/2));
-
+    pl = tosort;
+    pr = pl + num;
+    pw = PyDimMem_NEW(num/2);
     if (!pw) {
         PyErr_NoMemory();
         return -1;
     }
+    @TYPE@_amergesort0(pl, pr, v, pw, len);
 
-    @TYPE@_amergesort0(pl, pr, v, pw, chars);
     PyDimMem_FREE(pw);
-
     return 0;
 }
 /**end repeat**/
@@ -746,7 +834,7 @@
     PyArray_Descr *descr;
 
     /**begin repeat
-       #TYPE=BOOL,BYTE,UBYTE,SHORT,USHORT,INT,UINT,LONG,ULONG,LONGLONG,ULONGLONG,FLOAT,DOUBLE,LONGDOUBLE,CFLOAT,CDOUBLE,CLONGDOUBLE#
+       #TYPE=BOOL,BYTE,UBYTE,SHORT,USHORT,INT,UINT,LONG,ULONG,LONGLONG,ULONGLONG,FLOAT,DOUBLE,LONGDOUBLE,CFLOAT,CDOUBLE,CLONGDOUBLE,STRING,UNICODE#
     **/
     descr = PyArray_DescrFromType(PyArray_@TYPE@);
     descr->f->sort[PyArray_QUICKSORT] = \
@@ -763,29 +851,6 @@
         (PyArray_ArgSortFunc *)@TYPE@_amergesort;
     /**end repeat**/
 
-    descr = PyArray_DescrFromType(PyArray_STRING);
-    descr->f->argsort[PyArray_MERGESORT] = \
-        (PyArray_ArgSortFunc *)STRING_amergesort;
-    descr->f->argsort[PyArray_QUICKSORT] = \
-        (PyArray_ArgSortFunc *)STRING_aquicksort;
-    descr->f->argsort[PyArray_HEAPSORT] = \
-        (PyArray_ArgSortFunc *)STRING_aheapsort;
-    descr->f->sort[PyArray_QUICKSORT] = \
-        (PyArray_SortFunc *)STRING_quicksort;
-    descr->f->sort[PyArray_HEAPSORT] = \
-        (PyArray_SortFunc *)STRING_heapsort;
-
-    descr = PyArray_DescrFromType(PyArray_UNICODE);
-    descr->f->argsort[PyArray_MERGESORT] = \
-        (PyArray_ArgSortFunc *)UNICODE_amergesort;
-    descr->f->argsort[PyArray_QUICKSORT] = \
-        (PyArray_ArgSortFunc *)UNICODE_aquicksort;
-    descr->f->argsort[PyArray_HEAPSORT] = \
-        (PyArray_ArgSortFunc *)UNICODE_aheapsort;
-    descr->f->sort[PyArray_QUICKSORT] = \
-        (PyArray_SortFunc *)UNICODE_quicksort;
-    descr->f->sort[PyArray_HEAPSORT] = \
-        (PyArray_SortFunc *)UNICODE_heapsort;
 }
 
 static struct PyMethodDef methods[] = {

Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py	2008-02-14 01:32:35 UTC (rev 4804)
+++ trunk/numpy/core/tests/test_multiarray.py	2008-02-14 07:24:13 UTC (rev 4805)
@@ -290,11 +290,11 @@
             c.sort(kind=kind)
             assert_equal(c, ai, msg)
 
-        # test string sorts. Only quick and heap sort are available.
+        # test string sorts.
         s = 'aaaaaaaa'
         a = np.array([s + chr(i) for i in range(100)])
         b = a[::-1].copy()
-        for kind in ['q', 'h'] :
+        for kind in ['q', 'm', 'h'] :
             msg = "string sort, kind=%s" % kind
             c = a.copy();
             c.sort(kind=kind)
@@ -303,11 +303,11 @@
             c.sort(kind=kind)
             assert_equal(c, a, msg)
 
-        # test unicode sort. Only quick and heap sort are available.
+        # test unicode sort.
         s = 'aaaaaaaa'
         a = np.array([s + chr(i) for i in range(100)], dtype=np.unicode)
         b = a[::-1].copy()
-        for kind in ['q', 'h'] :
+        for kind in ['q', 'm', 'h'] :
             msg = "unicode sort, kind=%s" % kind
             c = a.copy();
             c.sort(kind=kind)



More information about the Numpy-svn mailing list