[Numpy-svn] r3325 - in trunk/numpy: core linalg

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Oct 13 13:18:21 CDT 2006


Author: oliphant
Date: 2006-10-13 13:18:18 -0500 (Fri, 13 Oct 2006)
New Revision: 3325

Modified:
   trunk/numpy/core/numeric.py
   trunk/numpy/linalg/linalg.py
Log:
Fix up the tensor solve and tensor inverse functions and rename them (tensorsolve, tensorinv) to match tensordot.

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-10-13 17:18:28 UTC (rev 3324)
+++ trunk/numpy/core/numeric.py	2006-10-13 18:18:18 UTC (rev 3325)
@@ -252,7 +252,7 @@
         pass
 
 
-def tensordot(a, b, axes=[-1,0]):
+def tensordot(a, b, axes=2):
     """tensordot returns the product for any (ndim >= 1) arrays.
 
     r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where
@@ -265,10 +265,19 @@
     When there is more than one axis to sum over, the corresponding
     arguments to axes should be sequences of the same length with the first
     axis to sum over given first in both sequences, the second axis second,
-    and so forth. 
+    and so forth.
+
+    If the axes argument is an integer, N, then the last N dimensions of a
+    and the first N dimensions of b are summed over.
     """
-    axes_a, axes_b = axes
     try:
+        iter(axes)
+    except:
+        axes_a = range(-axes,0)
+        axes_b = range(0,axes)
+    else:
+        axes_a, axes_b = axes
+    try:
         na = len(axes_a)
         axes_a = list(axes_a)
     except TypeError:

Modified: trunk/numpy/linalg/linalg.py
===================================================================
--- trunk/numpy/linalg/linalg.py	2006-10-13 17:18:28 UTC (rev 3324)
+++ trunk/numpy/linalg/linalg.py	2006-10-13 18:18:18 UTC (rev 3325)
@@ -6,7 +6,7 @@
 zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf.
 """
 
-__all__ = ['solve', 'solvetensor', 'invtensor',
+__all__ = ['solve', 'tensorsolve', 'tensorinv',
            'inv', 'cholesky',
            'eigvals',
            'eigvalsh', 'pinv',
@@ -122,10 +122,11 @@
 
 # Linear equations
 
-def solvetensor(a, b, axes=None):
+def tensorsolve(a, b, axes=None):
     """Solves the tensor equation a x = b for x
 
-    where it is assumed that all the indices of x are summed over in the product.
+    where it is assumed that all the indices of x are summed over in
+    the product.
 
     a can be N-dimensional.  x will have the dimensions of A subtracted from
     the dimensions of b.
@@ -181,40 +182,48 @@
         return b.transpose().astype(result_t)
 
 
-def invtensor(a, ind=2):
-    """Find the inverse tensor.
+def tensorinv(a, ind=2):
+    """Find the 'inverse' of an N-d array
 
-    ind > 0 ==> first (ind) indices of a are summed over
-    ind < 0 ==> last (-ind) indices of a are summed over
+    ind must be a positive integer specifying
+    how many indices at the front of the array are involved
+    in the inverse sum.
+    
+    The result is ainv with shape a.shape[ind:] + a.shape[:ind]
 
-    if ind is a list, then it specifies the summed over axes
+    tensordot(ainv, a, ind) is an identity operator
 
-    When the inv tensor and the tensor are summed over the
-    indicated axes a separable identity tensor remains.
+    and so is
 
-    The inverse has the summed over axes at the end.
+    tensordot(a, ainv, a.ndim - ind)
+
+    Example:
+
+       a = rand(4,6,8,3)
+       ainv = tensorinv(a)
+       # ainv.shape is (8,3,4,6)
+       # suppose b has shape (4,6)
+       tensordot(ainv, b) # produces the same (8,3)-shaped output as
+       tensorsolve(a, b)
+
+       a = rand(24,8,3)
+       ainv = tensorinv(a,1)
+       # ainv.shape is (8,3,24)
+       # suppose b has shape (24,)
+       tensordot(ainv, b, 1)  # produces the same (8,3)-shaped output as
+       tensorsolve(a, b)
+
     """
-
     a = asarray(a)
     oldshape = a.shape
     prod = 1
-    if iterable(ind):
-        invshape = range(a.ndim)
-        for axis in ind:
-            invshape.remove(axis)
-            invshape.insert(a.ndim,axis)
-            prod *= oldshape[axis]
-    elif ind > 0:
+    if ind > 0:
         invshape = oldshape[ind:] + oldshape[:ind]
-        for k in oldshape[:ind]:
+        for k in oldshape[ind:]:
             prod *= k
-    elif ind < 0:
-        invshape = oldshape[:-ind] + oldshape[-ind:]
-        for k in oldshape[-ind:]:
-            prod *= k
     else:
         raise ValueError, "Invalid ind argument."
-    a = a.reshape(-1,prod)
+    a = a.reshape(prod,-1)
     ia = inv(a)
     return ia.reshape(*invshape)
 



More information about the Numpy-svn mailing list