[Numpy-discussion] apply_along, apply_over w/ subarrays
Pierre GM
pgmdevlist at gmail.com
Sun Nov 26 16:11:24 CST 2006
Folks,
Is there any reason why `apply_along_axis` and `apply_over_axes` (and, I
expect, `vectorize` as well, though I haven't tried) don't support subclasses
of ndarray? Both convert their input with `asarray`, which drops the subclass.
Would it be possible to use `asanyarray` instead of `asarray` in those
functions?
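A minimal illustration of the difference, assuming a recent NumPy and using a
`numpy.ma` masked array as the example subclass: `asarray` returns a
base-class ndarray and silently drops the mask, while `asanyarray` passes the
subclass through unchanged.

```python
import numpy as np

# A masked array is an ndarray subclass; here the middle value is masked.
m = np.ma.array([1, 2, 3], mask=[False, True, False])

plain = np.asarray(m)     # base-class ndarray: the mask is dropped
kept = np.asanyarray(m)   # the subclass passes through unchanged

print(type(plain).__name__)  # ndarray
print(type(kept).__name__)   # MaskedArray
print(plain.tolist())        # [1, 2, 3]: the masked 2 is exposed again
```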
Oh, and I ran into some problems with `apply_along_axis`: in the uncommon case
where the output of the 1-d function is a scalar array (e.g., N.array(0)) or
a masked singleton, the current version raises a TypeError, because the output
is not a scalar (so `isscalar` returns False) but doesn't have a length either
(so `len(res)` fails). I tried to correct that, but I'd prefer to get some
feedback before submitting the patch.
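The failure can be reproduced with a 0-d array, assuming a recent NumPy.
`is_scalar_like` is a hypothetical helper that mirrors the check proposed in
the patch: treat the result as a single value if `isscalar` says so, or if
`len()` raises TypeError.

```python
import numpy as np


def is_scalar_like(res):
    """Hypothetical helper mirroring the patch's test for a 'smaller output'."""
    if np.isscalar(res):
        return True
    try:
        len(res)
    except TypeError:   # e.g. a 0-d array: not a scalar, but no length
        return True
    return False


zero_d = np.array(0)
print(np.isscalar(zero_d))        # False: a 0-d array is not a scalar
try:
    len(zero_d)
except TypeError as exc:          # this is where the current code fails
    print("len() failed:", exc)
print(is_scalar_like(zero_d))         # True
print(is_scalar_like(3.0))            # True
print(is_scalar_like(np.arange(3)))   # False
```

Checking `np.ndim(res) == 0` might be an alternative that avoids relying on
the exception.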
Thanks for everything.
P.
#################################################
--- numpy/lib/shape_base.py.init 2006-11-26 13:22:57.000000000 -0500
+++ numpy/lib/shape_base.py 2006-11-26 17:01:22.000000000 -0500
@@ -13,7 +13,7 @@ def apply_along_axis(func1d,axis,arr,*ar
         and arr is an N-d array. i varies so as to apply the function
         along the given axis for each 1-d subarray in arr.
     """
-    arr = asarray(arr)
+    arr = asanyarray(arr)
     nd = arr.ndim
     if axis < 0:
         axis += nd
@@ -28,9 +28,16 @@ def apply_along_axis(func1d,axis,arr,*ar
     outshape = asarray(arr.shape).take(indlist)
     i.put(indlist, ind)
     res = func1d(arr[tuple(i.tolist())],*args)
-    # if res is a number, then we have a smaller output array
-    if isscalar(res):
-        outarr = zeros(outshape,asarray(res).dtype)
+    # if res is a number, or doesn't have a length, then we have a smaller output array
+    asscalar = isscalar(res)
+    if not asscalar:
+        try:
+            len(res)
+        except TypeError:
+            asscalar = True
+    #
+    if asscalar:
+        outarr = zeros(outshape,asarray(res).dtype).view(res.__class__)
         outarr[ind] = res
         Ntot = product(outshape)
         k = 1
@@ -52,7 +59,7 @@ def apply_along_axis(func1d,axis,arr,*ar
         holdshape = outshape
         outshape = list(arr.shape)
         outshape[axis] = len(res)
-        outarr = zeros(outshape,asarray(res).dtype)
+        outarr = zeros(outshape,asarray(res).dtype).view(res.__class__)
         outarr[tuple(i.tolist())] = res
         k = 1
         while k < Ntot:
@@ -78,7 +85,7 @@ def apply_over_axes(func, a, axes):
     to be either the same shape as a or have one less dimension.
     This call is repeated for each axis in the axes sequence.
     """
-    val = asarray(a)
+    val = asanyarray(a)
     N = a.ndim
     if array(axes).ndim == 0:
         axes = (axes,)
#################################################
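For comparison, here is a simplified sketch of the same idea written against
a recent NumPy. It is not the patch itself: `apply_along_axis_keep_class` is
a hypothetical name, it walks the slices with `np.ndindex` instead of the
hand-rolled counter in shape_base.py, and it uses `np.ndim(res) == 0` in
place of the isscalar/len test. It also only re-views the output as
`type(res)` when `res` really is an ndarray subclass, because `.view()` given
a class that is not an ndarray subclass is taken as a dtype request rather
than a subclass. The masked singleton is viewed as a plain `MaskedArray`,
on the assumption that its own special class is not a good output type.

```python
import numpy as np


def apply_along_axis_keep_class(func1d, axis, arr, *args):
    """Hypothetical, simplified take on the patched apply_along_axis.

    Assumes every call to func1d returns either a single value
    (np.ndim(res) == 0) or a 1-d result of the same length, and that
    arr has no zero-length axes.
    """
    arr = np.asanyarray(arr)          # keep ndarray subclasses
    axis = axis % arr.ndim            # accept negative axes
    # Iterate over every axis except `axis`.
    outer_shape = arr.shape[:axis] + arr.shape[axis + 1:]

    outarr = None
    scalar_out = False
    for idx in np.ndindex(*outer_shape):
        # Index tuple selecting one 1-d slice along `axis`.
        full_idx = idx[:axis] + (slice(None),) + idx[axis:]
        res = func1d(arr[full_idx], *args)
        if outarr is None:
            # Size the output from the first result, as the original does.
            scalar_out = np.ndim(res) == 0
            if scalar_out:
                shape = outer_shape
            else:
                shape = list(arr.shape)
                shape[axis] = len(res)
            outarr = np.zeros(shape, np.asarray(res).dtype)
            cls = type(res)
            if res is np.ma.masked:
                cls = np.ma.MaskedArray   # use the regular masked class
            # Only re-view when res really is an ndarray subclass.
            if issubclass(cls, np.ndarray) and cls is not np.ndarray:
                outarr = outarr.view(cls)
        if scalar_out:
            outarr[idx] = res
        else:
            outarr[full_idx] = res
    return outarr
```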
