# [Scipy-svn] r6274 - in trunk/scipy/linalg: . tests

scipy-svn@scip... scipy-svn@scip...
Fri Mar 26 19:17:23 CDT 2010

```
Author: warren.weckesser
Date: 2010-03-26 19:17:22 -0500 (Fri, 26 Mar 2010)
New Revision: 6274

Modified:
trunk/scipy/linalg/basic.py
trunk/scipy/linalg/tests/test_basic.py
Log:
ENH: Allow linalg.block_diag to accept scalar and 1D arguments (ticket #1128)

Modified: trunk/scipy/linalg/basic.py
===================================================================
--- trunk/scipy/linalg/basic.py	2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/basic.py	2010-03-27 00:17:22 UTC (rev 6274)
@@ -18,7 +18,7 @@
from numpy import asarray, zeros, sum, greater_equal, subtract, arange,\
conjugate, dot, transpose
import numpy
-from numpy import asarray_chkfinite, outer, concatenate, reshape, single
+from numpy import asarray_chkfinite, atleast_2d, outer, concatenate, reshape, single
from numpy import matrix as Matrix
from numpy.linalg import LinAlgError
from scipy.linalg import calc_lwork
@@ -894,7 +894,7 @@
return concatenate(concatenate(o, axis=1), axis=1)

def block_diag(*arrs):
-    """Create a diagonal matrix from the provided arrays.
+    """Create a block diagonal matrix from the provided arrays.

Given the inputs `A`, `B` and `C`, the output will have these
arrays arranged on the diagonal::
@@ -908,8 +908,9 @@

Parameters
----------
-    A, B, C, ... : 2-D ndarray
-        Input arrays.
+    A, B, C, ... : array-like, up to 2D
+        Input arrays.  A 1D array or array-like sequence with length n is
+        treated as a 2D array with shape (1,n).

Returns
-------
@@ -929,15 +930,28 @@
>>> B = [[3, 4, 5],
...      [6, 7, 8]]
>>> C = [[7]]
-    >>> print block_diag(A, B, C)
-    [[ 1.  0.  0.  0.  0.  0.]
-     [ 0.  1.  0.  0.  0.  0.]
-     [ 0.  0.  3.  4.  5.  0.]
-     [ 0.  0.  6.  7.  8.  0.]
-     [ 0.  0.  0.  0.  0.  7.]]
+    >>> print(block_diag(A, B, C))
+    [[1 0 0 0 0 0]
+     [0 1 0 0 0 0]
+     [0 0 3 4 5 0]
+     [0 0 6 7 8 0]
+     [0 0 0 0 0 7]]
+    >>> block_diag(1.0, [2, 3], [[4, 5], [6, 7]])
+    array([[ 1.,  0.,  0.,  0.,  0.],
+           [ 0.,  2.,  3.,  0.,  0.],
+           [ 0.,  0.,  0.,  4.,  5.],
+           [ 0.,  0.,  0.,  6.,  7.]])

"""
-    arrs = [asarray(a) for a in arrs]
+    if arrs == ():
+        arrs = ([],)
+    arrs = [atleast_2d(a) for a in arrs]
+
+    bad_args = [k for k in range(len(arrs)) if arrs[k].ndim > 2]
+    if bad_args:
+        raise ValueError("arguments in the following positions have dimension "
+                            "greater than 2: %s" % bad_args)
+
shapes = numpy.array([a.shape for a in arrs])
out = zeros(sum(shapes, axis=0), dtype=arrs[0].dtype)

@@ -947,4 +961,3 @@
r += rr
c += cc
return out
-

Modified: trunk/scipy/linalg/tests/test_basic.py
===================================================================
--- trunk/scipy/linalg/tests/test_basic.py	2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/tests/test_basic.py	2010-03-27 00:17:22 UTC (rev 6274)
@@ -463,7 +463,24 @@

x = block_diag([[True]])
assert_equal(x.dtype, bool)
+
+    def test_scalar_and_1d_args(self):
+        a = block_diag(1)
+        assert_equal(a.shape, (1,1))
+        assert_array_equal(a, [[1]])
+
+        a = block_diag([2,3], 4)
+        assert_array_equal(a, [[2, 3, 0], [0, 0, 4]])

+    def test_bad_arg(self):
+        assert_raises(ValueError, block_diag, [[[1]]])
+
+    def test_no_args(self):
+        a = block_diag()
+        assert_equal(a.ndim, 2)
+        assert_equal(a.nbytes, 0)
+
+
class TestPinv(TestCase):

def test_simple(self):

```