[Numpy-discussion] How to set array values based on a condition?
Damian Eads
eads@soe.ucsc....
Sun Mar 23 01:10:24 CDT 2008
Hi,
I am working on a memory-intensive experiment with very large arrays, so
I must be careful when allocating memory. NumPy already supports a
number of in-place operations (+=, *=), which make the task much more
manageable. However, it is not obvious to me how to set values based on a
very simple condition.
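The in-place operators mentioned above write into the array's existing buffer instead of allocating a new one, and a ufunc's `out=` argument does the same explicitly. A small check on an arbitrary example array (the names `y`, `buf` and `z` are illustrative only):

```python
import numpy as np

y = np.arange(6, dtype=np.float64)
buf = y.ctypes.data          # address of y's data buffer

y += 1                       # in place: reuses y's buffer
y *= 2                       # in place: reuses y's buffer
np.add(y, 1, out=y)          # explicit in-place ufunc call

assert y.ctypes.data == buf  # still the same buffer
assert y.tolist() == [3.0, 5.0, 7.0, 9.0, 11.0, 13.0]

z = y + 1                    # not in place: allocates a new array
assert z.ctypes.data != buf
```

The masked assignment in the next paragraph is different: the assignment itself is in place, but evaluating the condition creates a full-size temporary first.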
The expression
y[y<0]=-1
generates a boolean index mask, y<0, with the same shape as the array y
(one byte per element), which is problematic when y is quite large.
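A boolean mask costs one byte per element, so for a float64 array the temporary `y < 0` is one eighth the size of `y` itself. One way to bound that temporary, sketched here under the assumption that `y` is C-contiguous, is to apply the same masked assignment to fixed-size slices so that only a chunk-sized mask exists at any time. The helper `set_negative_chunked` and its default `chunk` size are hypothetical, not NumPy API.

```python
import numpy as np


def set_negative_chunked(y, value=-1, chunk=1 << 20):
    """Hypothetical helper: in-place equivalent of y[y < 0] = value.

    Each pass builds a boolean mask for at most `chunk` elements,
    instead of one mask covering all of y.
    """
    if not y.flags.c_contiguous:
        raise ValueError("y must be C-contiguous")
    flat = y.reshape(-1)                   # a view, since y is C-contiguous
    for start in range(0, flat.size, chunk):
        part = flat[start:start + chunk]   # slice is also a view into y
        part[part < 0] = value             # mask has len(part) elements
    return y


y = np.array([[-3.0, 2.0], [0.5, -0.1]])
mask = y < 0
assert mask.dtype == np.bool_ and mask.nbytes == y.size  # 1 byte per element
set_negative_chunked(y, chunk=3)
assert y.tolist() == [[-1.0, 2.0], [0.5, -1.0]]
```

Larger chunks mean fewer Python-level iterations; smaller chunks mean smaller temporaries, so the chunk size trades speed against peak memory.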
I was wondering whether there is anything like a set_where(A, cmp, B,
setval, [optional elseval]) function, where cmp would be a comparison
operator expressed as a string.
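The behaviour described can also be sketched in pure NumPy, without compiled code, by mapping the comparator string to a function from the standard `operator` module and processing the array in chunks. The function `set_where_chunked` below is hypothetical; it assumes `x` is C-contiguous and that any array arguments have `x.size` elements, and it uses `np.copyto(..., where=...)` for the masked writes.

```python
import operator
import numpy as np

# Comparator strings mapped to elementwise operators from the stdlib.
_CMPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
         "<=": operator.le, ">": operator.gt, ">=": operator.ge}


def set_where_chunked(x, cmp, cmpv, v, ev=None, chunk=1 << 20):
    """Hypothetical pure-NumPy sketch of set_where (not a NumPy function).

    Where (x cmp cmpv) holds, x is set to v; elsewhere to ev, if ev is
    given. cmpv, v and ev may be scalars or arrays with x.size elements.
    Each temporary mask holds at most `chunk` elements.
    """
    if cmp not in _CMPS:
        raise ValueError("%s is not one of the valid comparators %s"
                         % (cmp, tuple(_CMPS)))
    if not x.flags.c_contiguous:
        raise ValueError("x must be C-contiguous")
    op = _CMPS[cmp]
    flat = x.reshape(-1)                    # a view, since x is C-contiguous

    def prep(a):
        # Scalars (and None) pass through; arrays are flattened once.
        if a is None or np.ndim(a) == 0:
            return a
        a = np.asarray(a).reshape(-1)
        if a.size != flat.size:
            raise ValueError("array arguments must have x.size elements")
        return a

    cmpv, v, ev = prep(cmpv), prep(v), prep(ev)

    def piece(a, s):
        # Slice arrays to the current chunk; scalars broadcast as-is.
        return a[s] if np.ndim(a) else a

    for start in range(0, flat.size, chunk):
        s = slice(start, start + chunk)
        part = flat[s]                      # view into x
        mask = op(part, piece(cmpv, s))     # computed before any writes
        # casting="unsafe" mirrors the casting of plain x[mask] = v.
        if ev is not None:
            np.copyto(part, piece(ev, s), casting="unsafe", where=~mask)
        np.copyto(part, piece(v, s), casting="unsafe", where=mask)
    return x


x = np.array([1, 2, 3, 4])
y = np.array([1, 0, 3, 0])
v = np.array([10, 20, 30, 40])
z = np.array([-1, -2, -3, -4])
set_where_chunked(x, "!=", y, v, z, chunk=3)
assert x.tolist() == [-1, 20, -3, 40]
```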
The code below illustrates what I want to do. Admittedly, it needs to be
cleaned up, but it's a proof of concept. Does NumPy provide any functions
that support the functionality of the code below?
Just a shot in the dark. Thanks!
Damian
import numpy
import scipy.weave

_valid_cmps = ("==", "<=", ">=", "<", ">", "!=")

def set_where(x, cmp, cmpv, v, ev=None):
    """
    Sets every value in the array x to a specific value given a condition.
    It performs

        x[x cmp cmpv] = v

    efficiently, where cmp can be any one of the strings
    "==", "<=", ">=", "<", ">", or "!=".

    x is assumed to be a contiguous numpy array; cmpv, v and ev may be
    scalars or arrays with the same number of elements as x.

    Examples:

    1. Sets x[i] to the value of -1 whenever x[i] > 0.
           set_where(x, ">", 0, -1)
    2. Sets x[i] to the value of v[i] whenever x[i] > 0.
       (x and v must be the same size)
           set_where(x, ">", 0, v)
    3. Sets x[i] to the value of v[i] whenever x[i] != y[i].
       (x, y and v must be the same size)
           set_where(x, "!=", y, v)
    4. Sets x[i] to the value of v[i] whenever x[i] != y[i].
       Otherwise sets x[i] = z[i].
       (x, y, v, and z must be the same size)
           set_where(x, "!=", y, v, z)
    """
    if cmp not in _valid_cmps:
        raise ValueError("%s is not one of the valid comparators (%s)"
                         % (cmp, _valid_cmps))

    # Array arguments are indexed element-wise in the C loop;
    # scalar arguments are used as-is.
    vind = "[i]" if isinstance(v, numpy.ndarray) else ""
    cmpvind = "[i]" if isinstance(cmpv, numpy.ndarray) else ""

    n = x.size
    i = 0
    # Names of the Python variables weave passes into the C code.
    arg_names = ["i", "x", "cmpv", "v", "n", "ev"]

    else_block = ""
    if ev is not None:
        evind = "[i]" if isinstance(ev, numpy.ndarray) else ""
        else_block = """
        else {
            x[i] = ev%s;
        }""" % evind
    else:
        ev = 0  # still passed to weave, so give it a harmless value

    # Visit each of the n elements exactly once: i = 0 .. n-1.
    code = """
    for (i = 0; i < n; i++) {
        if (x[i] %s cmpv%s) {
            x[i] = v%s;
        }%s
    }
    """ % (cmp, cmpvind, vind, else_block)
    scipy.weave.inline(code, arg_names)
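To check set_where on small inputs, the four docstring examples can be reproduced with ordinary boolean masks. This hypothetical `set_where_reference` allocates exactly the full-size temporaries that set_where is meant to avoid, so it is only a sketch of a test oracle, assuming small arrays of matching size.

```python
import numpy as np


def set_where_reference(x, cmp, cmpv, v, ev=None):
    """Hypothetical mask-based reference for testing set_where.

    Allocates full-size temporaries, so it is not memory-frugal.
    """
    mask = {"==": np.equal, "!=": np.not_equal, "<": np.less,
            "<=": np.less_equal, ">": np.greater,
            ">=": np.greater_equal}[cmp](x, cmpv)
    if ev is None:
        ev = x                       # keep x where the condition fails
    x[...] = np.where(mask, v, ev)   # np.where builds a new array first
    return x


# Examples 1 and 2 from the docstring.
x = np.array([3.0, -2.0, 5.0])
assert set_where_reference(x, ">", 0, -1).tolist() == [-1.0, -2.0, -1.0]
x = np.array([3.0, -2.0, 5.0])
v = np.array([7.0, 8.0, 9.0])
assert set_where_reference(x, ">", 0, v).tolist() == [7.0, -2.0, 9.0]

# Examples 3 and 4 from the docstring.
y = np.array([1, 5, 3])
v = np.array([10, 20, 30])
z = np.array([0, 0, 0])
assert set_where_reference(np.array([1, 2, 3]), "!=", y, v).tolist() == [1, 20, 3]
assert set_where_reference(np.array([1, 2, 3]), "!=", y, v, z).tolist() == [0, 20, 0]
```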
