[Numpy-discussion] A faster median (Wirth's method)

Sturla Molden sturla@molden...
Mon Aug 31 23:06:50 CDT 2009


We recently had a discussion about optimizing NumPy's median to average 
O(n) complexity. After some searching, I found a selection algorithm 
that is competitive in speed with Hoare's quick select and has the 
advantage of being a lot simpler to implement. In plain Python:

import numpy as np

def wirthselect(array, k):
   
    """ Niklaus Wirth's selection algorithm """

    a = np.ascontiguousarray(array)
    if (a is array): a = a.copy()

    l = 0
    m = a.shape[0] - 1
    while l < m:
        x = a[k]
        i = l
        j = m
        while 1:
            while a[i] < x: i += 1
            while x < a[j]: j -= 1
            if i <= j:
                tmp = a[i]
                a[i] = a[j]
                a[j] = tmp
                i += 1
                j -= 1
            if i > j: break
        if j < k: l = i
        if k < i: m = j

    return a


Now, the median can be obtained in average O(n) time as:


def median(x):

    """ median in average O(n) time """

    n = x.shape[0]
    k = n >> 1
    s = wirthselect(x, k)
    if n & 1:
        return s[k]
    else:
        return 0.5*(s[k]+s[:k].max())
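
As a quick sanity check, here is a minimal sketch of the same selection 
loop on plain Python lists, compared against statistics.median from the 
standard library. The names wirth_select_list and median_list are made up 
for this illustration; they are not part of NumPy or of the code above.

```python
import random
import statistics

def wirth_select_list(values, k):
    """Return a copy of values, reordered so the k-th smallest sits at index k."""
    a = list(values)                 # work on a copy, leave the input alone
    lo, hi = 0, len(a) - 1
    while lo < hi:
        pivot = a[k]
        i, j = lo, hi
        while True:
            while a[i] < pivot: i += 1
            while pivot < a[j]: j -= 1
            if i <= j:
                a[i], a[j] = a[j], a[i]
                i += 1
                j -= 1
            if i > j:
                break
        if j < k: lo = i             # k-th element lies in the right part
        if k < i: hi = j             # k-th element lies in the left part
    return a

def median_list(values):
    """Median via selection; same even/odd handling as median() above."""
    n = len(values)
    k = n >> 1
    s = wirth_select_list(values, k)
    if n & 1:
        return s[k]
    return 0.5 * (s[k] + max(s[:k]))

# Compare against the sort-based reference on small random inputs,
# including inputs with many duplicates.
random.seed(0)
for n in range(1, 50):
    data = [random.randint(-20, 20) for _ in range(n)]
    assert median_list(data) == statistics.median(data)
```

The small integer range is deliberate: it forces many repeated values, 
which is where a partition loop is most likely to go wrong.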


The beauty of this is that Wirth select is extremely easy to migrate to 
Cython:


import numpy as np
cimport numpy
ctypedef numpy.double_t T # or whatever

def wirthselect(numpy.ndarray[T, ndim=1] array, int k):
   
    cdef int i, j, l, m
    cdef T x, tmp
    cdef T *a
    cdef numpy.ndarray[T, ndim=1] _array

    # work on a contiguous copy so the caller's array is not modified
    _array = np.ascontiguousarray(array)
    if (_array is array): _array = _array.copy()
    a = <T *> _array.data

    l = 0
    m = <int> _array.shape[0] - 1
    with nogil:
        while l < m:
            x = a[k]
            i = l
            j = m
            while 1:
                while a[i] < x: i += 1
                while x < a[j]: j -= 1
                if i <= j:
                    tmp = a[i]
                    a[i] = a[j]
                    a[j] = tmp
                    i += 1
                    j -= 1
                if i > j: break
            if j < k: l = i
            if k < i: m = j

    return _array


For example, we could have a small script that generates wirthselect for 
all NumPy dtypes (with T as the template parameter), and use a dict as a 
jump table.
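
One way such a generator script might look is sketched below. This is only 
an illustration, under some assumptions: the DTYPES mapping, the template 
text, the per-type function names (wirthselect_float64 and so on) and the 
_jump_table name are all made up here, and the selection loop itself is 
left as a comment since it is the same for every type. The script only 
produces Cython source text; it does not compile anything.

```python
from string import Template

# Assumed mapping: dtype name -> element type as declared in numpy.pxd.
DTYPES = {
    "float32": "numpy.float32_t",
    "float64": "numpy.float64_t",
    "int32": "numpy.int32_t",
    "int64": "numpy.int64_t",
}

# Hypothetical per-type template; ${ctype} plays the role of T.
FUNC_TEMPLATE = Template(
    "def wirthselect_${name}(numpy.ndarray[${ctype}, ndim=1] array, int k):\n"
    "    cdef ${ctype} x, tmp\n"
    "    cdef ${ctype} *a\n"
    "    # selection loop goes here, identical for every type\n"
)

def generate_source(dtypes=DTYPES):
    """Return Cython source text: one function per dtype plus a dispatcher."""
    parts = [FUNC_TEMPLATE.substitute(name=name, ctype=ctype)
             for name, ctype in dtypes.items()]
    # The dict acts as the jump table, keyed by dtype name.
    entries = "".join(f"    '{name}': wirthselect_{name},\n" for name in dtypes)
    parts.append("_jump_table = {\n" + entries + "}\n")
    parts.append(
        "def wirthselect(array, k):\n"
        "    return _jump_table[array.dtype.name](array, k)\n"
    )
    return "\n".join(parts)

if __name__ == "__main__":
    print(generate_source())
```

Keying the table on array.dtype.name keeps the dispatcher a single dict 
lookup; an unsupported dtype would surface as a KeyError, which a real 
version would probably want to turn into a fallback to sorting.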

Chad, you can continue to write quick select using NumPy's C quick sort 
in numpy/core/src/_sortmodule.c.src.  When you are done, it might be 
about 10% faster than this. :-)


Reference:
http://ndevilla.free.fr/median/median.pdf



Best regards,
Sturla Molden








