[Numpy-svn] r3167 - trunk/numpy/core

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Sep 15 17:41:06 CDT 2006


Author: oliphant
Date: 2006-09-15 17:40:54 -0500 (Fri, 15 Sep 2006)
New Revision: 3167

Modified:
   trunk/numpy/core/numeric.py
Log:
Add rollaxis command and fix cross function

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-09-15 21:18:33 UTC (rev 3166)
+++ trunk/numpy/core/numeric.py	2006-09-15 22:40:54 UTC (rev 3167)
@@ -7,7 +7,7 @@
            'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
            'isfortran', 'empty_like', 'zeros_like',
            'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
-           'alterdot', 'restoredot', 'cross', 'tensordot',
+           'alterdot', 'restoredot', 'rollaxis', 'cross', 'tensordot',
            'array2string', 'get_printoptions', 'set_printoptions',
            'array_repr', 'array_str', 'set_string_function',
            'little_endian', 'require',
@@ -322,16 +322,38 @@
     res = dot(at, bt)
     return res.reshape(olda + oldb)
 
+def rollaxis(a, axis, start=0):
+    """Return transposed array so that axis is rolled before start.
 
-def _move_axis_to_0(a, axis):
-    if axis == 0:
-        return a
+    if a.shape is (3,4,5,6)
+    rollaxis(a, 3, 1).shape is (3,6,4,5)
+    rollaxis(a, 2, 0).shape is (5,3,4,6)
+    rollaxis(a, 1, 3).shape is (3,5,4,6)
+    rollaxis(a, 1, 4).shape is (3,5,6,4)
+    """
     n = a.ndim
     if axis < 0:
         axis += n
-    axes = range(1, axis+1) + [0,] + range(axis+1, n)
+    if start < 0:
+        start += n
+    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
+    if not (0 <= axis < n):
+        raise ValueError, msg % ('axis', axis, n)
+    if not (0 <= start < n+1):
+        raise ValueError, msg % ('start', start, n+1)
+    if (axis < start): # it's been removed 
+        start -= 1    
+    if axis==start:
+        return a
+    axes = range(0,n)
+    axes.remove(axis)
+    axes.insert(start, axis)
     return a.transpose(axes)
 
+# fix hack in scipy which imports this function
+def _move_axis_to_0(a, axis):
+    return rollaxis(a, axis, 0)
+
 def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
     """Return the cross product of two (arrays of) vectors.
 
@@ -342,8 +364,8 @@
     """
     if axis is not None:
         axisa,axisb,axisc=(axis,)*3
-    a = _move_axis_to_0(asarray(a), axisa)
-    b = _move_axis_to_0(asarray(b), axisb)
+    a = asarray(a).swapaxes(axisa, 0)
+    b = asarray(b).swapaxes(axisb, 0)
     msg = "incompatible dimensions for cross product\n"\
           "(dimension must be 2 or 3)"
     if (a.shape[0] not in [2,3]) or (b.shape[0] not in [2,3]):
@@ -354,7 +376,7 @@
             if cp.ndim == 0:
                 return cp
             else:
-                return cp.swapaxes(0,axisc)
+                return cp.swapaxes(0, axisc)
         else:
             x = a[1]*b[2]
             y = -a[0]*b[2]



More information about the Numpy-svn mailing list