[Numpy-svn] r4991 - trunk/numpy/core

numpy-svn@scip... numpy-svn@scip...
Tue Apr 8 19:16:18 CDT 2008


Author: oliphant
Date: 2008-04-08 19:16:09 -0500 (Tue, 08 Apr 2008)
New Revision: 4991

Modified:
   trunk/numpy/core/numeric.py
Log:
Improve empty_like and zeros_like to respect the input's ndarray subtype (the result has the same subclass as the input).

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2008-04-08 20:43:07 UTC (rev 4990)
+++ trunk/numpy/core/numeric.py	2008-04-09 00:16:09 UTC (rev 4991)
@@ -37,26 +37,33 @@
 ALLOW_THREADS = multiarray.ALLOW_THREADS
 BUFSIZE = multiarray.BUFSIZE
 
+ndarray = multiarray.ndarray
+flatiter = multiarray.flatiter
+broadcast = multiarray.broadcast
+dtype = multiarray.dtype
+ufunc = type(sin)
 
-# from Fernando Perez's IPython
+
+# originally from Fernando Perez's IPython
 def zeros_like(a):
     """Return an array of zeros of the shape and data-type of a.
 
     If you don't explicitly need the array to be zeroed, you should instead
-    use empty_like(), which is faster as it only allocates memory.
+    use empty_like(), which is a bit faster as it only allocates memory.
     """
+    if isinstance(a, ndarray):
+        res = ndarray.__new__(type(a), a.shape, a.dtype, order=a.flags.fnc)
+        res.fill(0)
+        return res
     try:
-        return zeros(a.shape, a.dtype, a.flags.fnc)
+        wrap = a.__array_wrap__
     except AttributeError:
-        try:
-            wrap = a.__array_wrap__
-        except AttributeError:
-            wrap = None
-        a = asarray(a)
-        res = zeros(a.shape, a.dtype)
-        if wrap:
-            res = wrap(res)
-        return res
+        wrap = None
+    a = asarray(a)
+    res = zeros(a.shape, a.dtype)
+    if wrap:
+        res = wrap(res)
+    return res
 
 def empty_like(a):
     """Return an empty (uninitialized) array of the shape and data-type of a.
@@ -65,18 +72,18 @@
     your array to be initialized, you should use zeros_like().
 
     """
+    if isinstance(a, ndarray):
+        res = ndarray.__new__(type(a), a.shape, a.dtype, order=a.flags.fnc)
+        return res
     try:
-        return empty(a.shape, a.dtype, a.flags.fnc)
+        wrap = a.__array_wrap__
     except AttributeError:
-        try:
-            wrap = a.__array_wrap__
-        except AttributeError:
-            wrap = None
-        a = asarray(a)
-        res = empty(a.shape, a.dtype)
-        if wrap:
-            res = wrap(res)
-        return res
+        wrap = None
+    a = asarray(a)
+    res = empty(a.shape, a.dtype)
+    if wrap:
+        res = wrap(res)
+    return res
 
 # end Fernando's utilities
 
@@ -98,11 +105,6 @@
 
 newaxis = None
 
-ndarray = multiarray.ndarray
-flatiter = multiarray.flatiter
-broadcast = multiarray.broadcast
-dtype = multiarray.dtype
-ufunc = type(sin)
 
 arange = multiarray.arange
 array = multiarray.array



More information about the Numpy-svn mailing list