# [Numpy-svn] r3167 - trunk/numpy/core

From: numpy-svn at scipy.org
Fri Sep 15 17:41:06 CDT 2006

```
Author: oliphant
Date: 2006-09-15 17:40:54 -0500 (Fri, 15 Sep 2006)
New Revision: 3167

Modified:
trunk/numpy/core/numeric.py
Log:
Add rollaxis command and fix cross function

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-09-15 21:18:33 UTC (rev 3166)
+++ trunk/numpy/core/numeric.py	2006-09-15 22:40:54 UTC (rev 3167)
@@ -7,7 +7,7 @@
'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
'isfortran', 'empty_like', 'zeros_like',
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
-           'alterdot', 'restoredot', 'cross', 'tensordot',
+           'alterdot', 'restoredot', 'rollaxis', 'cross', 'tensordot',
'array2string', 'get_printoptions', 'set_printoptions',
'array_repr', 'array_str', 'set_string_function',
'little_endian', 'require',
@@ -322,16 +322,38 @@
res = dot(at, bt)
return res.reshape(olda + oldb)

+def rollaxis(a, axis, start=0):
+    """Return transposed array so that axis is rolled before start.

-def _move_axis_to_0(a, axis):
-    if axis == 0:
-        return a
+    if a.shape is (3,4,5,6)
+    rollaxis(a, 3, 1).shape is (3,6,4,5)
+    rollaxis(a, 2, 0).shape is (5,3,4,6)
+    rollaxis(a, 1, 3).shape is (3,5,4,6)
+    rollaxis(a, 1, 4).shape is (3,5,6,4)
+    """
n = a.ndim
if axis < 0:
axis += n
-    axes = range(1, axis+1) + [0,] + range(axis+1, n)
+    if start < 0:
+        start += n
+    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
+    if not (0 <= axis < n):
+        raise ValueError, msg % ('axis', axis, n)
+    if not (0 <= start < n+1):
+        raise ValueError, msg % ('start', start, n+1)
+    if (axis < start): # it's been removed
+        start -= 1
+    if axis==start:
+        return a
+    axes = range(0,n)
+    axes.remove(axis)
+    axes.insert(start, axis)
return a.transpose(axes)

+# fix hack in scipy which imports this function
+def _move_axis_to_0(a, axis):
+    return rollaxis(a, axis, 0)
+
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""Return the cross product of two (arrays of) vectors.

@@ -342,8 +364,8 @@
"""
if axis is not None:
axisa,axisb,axisc=(axis,)*3
-    a = _move_axis_to_0(asarray(a), axisa)
-    b = _move_axis_to_0(asarray(b), axisb)
+    a = asarray(a).swapaxes(axisa, 0)
+    b = asarray(b).swapaxes(axisb, 0)
msg = "incompatible dimensions for cross product\n"\
"(dimension must be 2 or 3)"
if (a.shape[0] not in [2,3]) or (b.shape[0] not in [2,3]):
@@ -354,7 +376,7 @@
if cp.ndim == 0:
return cp
else:
-                return cp.swapaxes(0,axisc)
+                return cp.swapaxes(0, axisc)
else:
x = a[1]*b[2]
y = -a[0]*b[2]

```