[Numpy-discussion] How to set array values based on a condition?

Damian Eads eads@soe.ucsc....
Sun Mar 23 01:10:24 CDT 2008


Hi,

I am working on a memory-intensive experiment with very large arrays, so
I must be careful when allocating memory. Numpy already supports a
number of in-place operations (+=, *=), which makes the task much more
manageable. However, it is not obvious to me how to set values based on a
very simple condition.
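A minimal sketch of the in-place behaviour referred to above, with arbitrary data: augmented assignment and a ufunc's out= argument write into the array's existing buffer, whereas an ordinary expression such as y * 2 allocates a new result array.

```python
import numpy as np

y = np.arange(6, dtype=np.float64)
original = y                  # second name for the same array object

y += 1                        # augmented assignment updates y's buffer in place
np.multiply(y, 2, out=y)      # ufuncs accept out= to write into an existing array
assert y is original

z = y * 2                     # an ordinary expression allocates a new result array
assert not np.shares_memory(z, original)
```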

The expression

   y[y<0]=-1

generates a temporary boolean mask, y<0, with one byte for every element
of y, which is problematic when y is very large.
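The size of that temporary can be checked directly. One way to bound it, sketched here under the assumption that y is C-contiguous (so reshape(-1) returns a view rather than a copy), is to evaluate the condition one fixed-size chunk at a time into a single reusable boolean buffer. The helper name set_negative_chunked and the default chunk size are hypothetical, chosen for illustration; they are not part of Numpy.

```python
import numpy as np

# A Numpy bool takes one byte, so the mask is one eighth the size of float64 data.
y = np.linspace(-1.0, 1.0, 1_000_000)
assert (y < 0).nbytes == 1_000_000 and y.nbytes == 8_000_000


def set_negative_chunked(y, value=-1, chunk=1 << 16):
    """Hypothetical helper: y[y < 0] = value, with temporaries sized by chunk, not y.size."""
    if not y.flags.c_contiguous:
        raise ValueError("y must be C-contiguous so reshape(-1) is a view")
    flat = y.reshape(-1)                                # view of y's data
    mask = np.empty(min(chunk, flat.size), dtype=bool)  # reused for every chunk
    for start in range(0, flat.size, chunk):
        block = flat[start:start + chunk]               # view, no copy
        m = mask[:block.size]
        np.less(block, 0, out=m)                        # comparison written into the buffer
        block[m] = value                                # assignment writes through to y
    return y
```

The result matches y[y<0] = -1; only the peak temporary memory changes.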

I was wondering if there was anything like a set_where(A, cmp, B, 
setval, [optional elseval]) function where cmp would be a comparison 
operator expressed as a string.
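For comparison, a pure-Numpy version of such an interface can be sketched by mapping each comparator string onto the matching comparison ufunc (np.equal, np.less_equal, and so on), writing the comparison for one chunk at a time into a reusable boolean buffer via out=, and then using np.copyto(..., where=...) to write the new values. This set_where is a hypothetical helper, not an existing Numpy function; it assumes x is C-contiguous and that array arguments have as many elements as x, with scalars broadcast. The chunk parameter is an illustrative addition.

```python
import numpy as np

# Map the comparator strings from the proposal onto Numpy comparison ufuncs.
_CMP_UFUNCS = {
    "==": np.equal, "!=": np.not_equal,
    "<": np.less, "<=": np.less_equal,
    ">": np.greater, ">=": np.greater_equal,
}


def _pick(a, sl):
    # Slice flattened array arguments; scalars pass through and broadcast.
    return a[sl] if np.ndim(a) else a


def set_where(x, cmp, cmpv, v, ev=None, chunk=1 << 16):
    """Hypothetical sketch: in place, x[x cmp cmpv] = v, and x = ev elsewhere if given.

    Temporaries are bounded by `chunk` elements (one reusable bool buffer,
    plus a chunk-sized ~m when ev is given), not by x.size.
    """
    if cmp not in _CMP_UFUNCS:
        raise ValueError("%r is not one of %s" % (cmp, sorted(_CMP_UFUNCS)))
    if not x.flags.c_contiguous:
        raise ValueError("x must be C-contiguous")
    op = _CMP_UFUNCS[cmp]
    xf = x.reshape(-1)  # a view, because x is C-contiguous

    def flat(a):
        if np.ndim(a) == 0:
            return a
        a = np.reshape(a, -1)
        if a.size != xf.size:
            raise ValueError("array arguments must have x.size elements")
        return a

    cmpv, v = flat(cmpv), flat(v)
    ev = None if ev is None else flat(ev)
    mask = np.empty(min(chunk, xf.size), dtype=bool)  # reused for every chunk
    for start in range(0, xf.size, chunk):
        sl = slice(start, start + chunk)
        block = xf[sl]                 # view into x; writes go straight to x
        m = mask[:block.size]
        op(block, _pick(cmpv, sl), out=m)
        if ev is not None:
            np.copyto(block, _pick(ev, sl), where=~m)
        np.copyto(block, _pick(v, sl), where=m)
    return x
```

Because the mask for each chunk is computed before anything is written, the order of the two np.copyto calls does not matter. np.copyto defaults to casting='same_kind', so writing float values into an integer x should raise a TypeError instead of truncating silently as x[mask] = 0.5 does.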

The code below illustrates what I want to do. Admittedly, it needs to be
cleaned up, but it is a proof of concept. Does Numpy provide any functions
that offer the functionality of the code below?

Just a shot in the dark. Thanks!

Damian



import numpy
import scipy.weave

_valid_cmps = ("==", "<=", ">=", "<", ">", "!=")

def set_where(x, cmp, cmpv, v, ev=None):
    """
    Sets every value in the array x to a specific value given a condition.
    It performs

         x[x cmp cmpv] = v

    in place, without allocating a temporary boolean mask, where cmp can be
    any one of the strings

         "==", "<=", ">=", "<", ">", or "!="

    x must be C-contiguous; cmpv, v and ev may each be a scalar or a
    contiguous array with the same size as x.

    Examples:

       1. Sets x[i] to the value of -1 whenever x[i] > 0.

          set_where(x, ">", 0, -1)

       2. Sets x[i] to the value of v[i] whenever x[i] > 0.
          (x and v must be the same size)

          set_where(x, ">", 0, v)

       3. Sets x[i] to the value of v[i] whenever x[i] != y[i].
          (x, y and v must be the same size)

          set_where(x, "!=", y, v)

       4. Sets x[i] to the value of v[i] whenever x[i] != y[i].
          Otherwise sets x[i] = z[i].

          (x, y, v, and z must be the same size)

          set_where(x, "!=", y, v, z)

    """
    if cmp not in _valid_cmps:
        raise ValueError("%s is not one of the valid comparators (%s)"
                         % (cmp, _valid_cmps))
    # The generated C indexes raw data pointers (x[i], v[i], ...), so every
    # array involved must be contiguous and hold as many elements as x.
    if not x.flags.c_contiguous:
        raise ValueError("x must be C-contiguous")

    def index_suffix(a):
        # Arrays are read element-wise in C; scalars are used as-is.
        if isinstance(a, numpy.ndarray):
            if a.size != x.size or not a.flags.c_contiguous:
                raise ValueError("array arguments must be contiguous and "
                                 "the same size as x")
            return '[i]'
        return ''

    vind = index_suffix(v)
    cmpvind = index_suffix(cmpv)
    n = x.size
    # Python variables passed to the C code (cmp is substituted into the source).
    arg_names = ['x', 'cmpv', 'v', 'n', 'ev']
    else_block = ""
    if ev is not None:
        else_block = """
        else {
            x[i] = ev%s;
        }
        """ % index_suffix(ev)
    else:
        # weave still needs a convertible value for the unused ev argument.
        ev = 0
    code = """
           int i;
           for (i = 0; i < n; i++) {   /* i < n: x[n] is past the end */
              if (x[i] %s cmpv%s) {
                 x[i] = v%s;
              } %s
           }
           """ % (cmp, cmpvind, vind, else_block)
    # print code  # uncomment to inspect the generated C source

    scipy.weave.inline(code, arg_names)
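Because the weave version needs a C compiler and a Python 2 environment (scipy.weave was never ported to Python 3), a slow element-by-element transliteration of the generated loop can serve as a reference on small inputs, for example to compare against the weave result or against x[x > 0] = -1. set_where_reference is a hypothetical test helper; Python's operator functions stand in for the C comparison operators.

```python
import operator

import numpy as np

# Python counterparts of the C comparison operators accepted by set_where.
_OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
        "<=": operator.le, ">": operator.gt, ">=": operator.ge}


def set_where_reference(x, cmp, cmpv, v, ev=None):
    """Hypothetical test helper: the generated C loop, one element at a time."""
    if not x.flags.c_contiguous:
        raise ValueError("x must be C-contiguous, as the C code assumes")
    xf = x.reshape(-1)                    # view of x's data

    def at(a, i):
        # Mirrors the "[i]" suffix: arrays are indexed, scalars used directly.
        return a.reshape(-1)[i] if isinstance(a, np.ndarray) else a

    for i in range(xf.size):              # i = 0 .. n-1; x[n] would be out of bounds
        if _OPS[cmp](xf[i], at(cmpv, i)):
            xf[i] = at(v, i)
        elif ev is not None:
            xf[i] = at(ev, i)
    return x
```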

