[Numpy-svn] r3090 - in trunk/numpy: . core core/src

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 28 21:47:45 CDT 2006


Author: oliphant
Date: 2006-08-28 21:47:39 -0500 (Mon, 28 Aug 2006)
New Revision: 3090

Modified:
   trunk/numpy/__init__.py
   trunk/numpy/core/numeric.py
   trunk/numpy/core/src/arrayobject.c
Log:
Add float, int, etc. to numpy name-space.  Flesh out tensordot.  Fix-up getcharbuf to allow all 8-bit types to be returned as a charbuf.

Modified: trunk/numpy/__init__.py
===================================================================
--- trunk/numpy/__init__.py	2006-08-28 20:56:23 UTC (rev 3089)
+++ trunk/numpy/__init__.py	2006-08-29 02:47:39 UTC (rev 3090)
@@ -36,6 +36,10 @@
     from core import *
     import lib
     from lib import *
+    # Make these accessible from numpy name-space
+    #  but not imported in from numpy import *
+    from __builtin__ import bool, int, long, float, complex, \
+         object, unicode, str
     import linalg
     import fft
     import random

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-08-28 20:56:23 UTC (rev 3089)
+++ trunk/numpy/core/numeric.py	2006-08-29 02:47:39 UTC (rev 3090)
@@ -190,7 +190,7 @@
     is a sequence of indices into a.  This sequence must be
     converted to a tuple in order to be used to index into a.
     """
-    return transpose(a.nonzero())
+    return asarray(a.nonzero()).T
 
 def flatnonzero(a):
     """Return indicies that are not-zero in flattened version of a
@@ -252,6 +252,7 @@
     def restoredot():
         pass
 
+
 def tensordot(a, b, axes=(-1,0)):
     """tensordot returns the product for any (ndim >= 1) arrays.
 
@@ -259,37 +260,67 @@
 
     the axes to be summed over are given by the axes argument.
     the first element of the sequence determines the axis or axes
-    in arr1 to sum over and the second element in axes argument sequence
+    in arr1 to sum over, and the second element in axes argument sequence
+    determines the axis or axes in arr2 to sum over. 
+
+    When there is more than one axis to sum over, the corresponding
+    arguments to axes should be sequences of the same length with the first
+    axis to sum over given first in both sequences, the second axis second,
+    and so forth. 
     """
     axes_a, axes_b = axes
     try:
         na = len(axes_a)
+        axes_a = list(axes_a)
     except TypeError:
         axes_a = [axes_a]
         na = 1
     try:
         nb = len(axes_b)
+        axes_b = list(axes_b)
     except TypeError:
         axes_b = [axes_b]
         nb = 1
 
     a, b = asarray(a), asarray(b)
     as = a.shape
+    nda = len(a.shape)
     bs = b.shape
+    ndb = len(b.shape)
     equal = 1
     if (na != nb): equal = 0
-    for k in xrange(na):
-        if as[axes_a[k]] != bs[axes_b[k]]:
-            equal = 0
-            break
-
+    else:
+        for k in xrange(na):
+            if as[axes_a[k]] != bs[axes_b[k]]:
+                equal = 0
+                break
+            if axes_a[k] < 0:
+                axes_a[k] += nda
+            if axes_b[k] < 0:
+                axes_b[k] += ndb
     if not equal:
-        raise ValueError, "shape-mismatch for sum"    
-    
-    olda = [k for k in aa if k not in axes_a]
-    oldb = [k for k in bs if k not in axes_b]
+        raise ValueError, "shape-mismatch for sum"
 
-    at = a.reshape(nd1, nd2)
+    # Move the axes to sum over to the end of "a"
+    # and to the front of "b"
+    notin = [k for k in range(nda) if k not in axes_a]
+    newaxes_a = notin + axes_a
+    N2 = 1
+    for axis in axes_a:
+        N2 *= as[axis]
+    newshape_a = (-1, N2)
+    olda = [as[axis] for axis in notin]
+
+    notin = [k for k in range(ndb) if k not in axes_b]
+    newaxes_b = axes_b + notin
+    N2 = 1
+    for axis in axes_b:
+        N2 *= bs[axis]
+    newshape_b = (N2, -1)
+    oldb = [bs[axis] for axis in notin]
+
+    at = a.transpose(newaxes_a).reshape(newshape_a)
+    bt = b.transpose(newaxes_b).reshape(newshape_b)
     res = dot(at, bt)
     return res.reshape(olda + oldb)
 

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-08-28 20:56:23 UTC (rev 3089)
+++ trunk/numpy/core/src/arrayobject.c	2006-08-29 02:47:39 UTC (rev 3090)
@@ -3038,12 +3038,13 @@
 array_getcharbuf(PyArrayObject *self, Py_ssize_t segment, constchar **ptrptr)
 {
         if (self->descr->type_num == PyArray_STRING || \
-            self->descr->type_num == PyArray_UNICODE)
+            self->descr->type_num == PyArray_UNICODE || \
+            self->descr->elsize == 1)
                 return array_getreadbuf(self, segment, (void **) ptrptr);
         else {
                 PyErr_SetString(PyExc_TypeError,
-                                "non-character array cannot be interpreted "\
-                                "as character buffer");
+                                "non-character (or 8-bit) array cannot be "\
+                                "interpreted as character buffer");
                 return -1;
         }
 }



More information about the Numpy-svn mailing list