[Scipy-svn] r3436 - in trunk/scipy/sparse: . tests
scipy-svn@scip...
scipy-svn@scip...
Mon Oct 15 14:23:10 CDT 2007
Author: wnbell
Date: 2007-10-15 14:23:07 -0500 (Mon, 15 Oct 2007)
New Revision: 3436
Modified:
trunk/scipy/sparse/sparse.py
trunk/scipy/sparse/tests/test_sparse.py
Log:
Fix sparse matvec result dimensions.
Check matvec input dimensions.
Resolves ticket #514.
Modified: trunk/scipy/sparse/sparse.py
===================================================================
--- trunk/scipy/sparse/sparse.py 2007-10-13 18:33:22 UTC (rev 3435)
+++ trunk/scipy/sparse/sparse.py 2007-10-15 19:23:07 UTC (rev 3436)
@@ -673,23 +673,25 @@
def _matvec(self, other, fn):
if isdense(other):
- # This check is too harsh -- it prevents a column vector from
- # being created on-the-fly like dense matrix objects can.
- #if len(other) != self.shape[1]:
- # raise ValueError, "dimension mismatch"
- oth = numpy.ravel(other)
+ if other.size != self.shape[1] or \
+ (other.ndim == 2 and self.shape[1] != other.shape[0]):
+ raise ValueError, "dimension mismatch"
+
y = fn(self.shape[0], self.shape[1], \
- self.indptr, self.indices, self.data, oth)
+ self.indptr, self.indices, self.data, numpy.ravel(other))
+
if isinstance(other, matrix):
y = asmatrix(y)
+
if other.ndim == 2 and other.shape[1] == 1:
- # If 'other' was an (nx1) column vector, transpose the result
- # to obtain an (mx1) column vector.
- y = y.T
+ # If 'other' was an (nx1) column vector, reshape the result
+ y = y.reshape(-1,1)
+
return y
elif isspmatrix(other):
raise TypeError, "use matmat() for sparse * sparse"
+
else:
raise TypeError, "need a dense vector"
Modified: trunk/scipy/sparse/tests/test_sparse.py
===================================================================
--- trunk/scipy/sparse/tests/test_sparse.py 2007-10-13 18:33:22 UTC (rev 3435)
+++ trunk/scipy/sparse/tests/test_sparse.py 2007-10-15 19:23:07 UTC (rev 3436)
@@ -176,7 +176,23 @@
M = self.spmatrix(matrix([[3,0,0],[0,1,0],[2,0,3.0],[2,3,0]]))
col = matrix([1,2,3]).T
assert_array_almost_equal(M * col, M.todense() * col)
+
+ #check result dimensions (ticket #514)
+ assert_equal((M * array([1,2,3])).shape,(4,))
+ assert_equal((M * array([[1],[2],[3]])).shape,(4,1))
+ assert_equal((M * matrix([[1],[2],[3]])).shape,(4,1))
+ #ensure exception is raised for improper dimensions
+ bad_vecs = [array([1,2]), array([1,2,3,4]), array([[1],[2]]),
+ matrix([1,2,3]), matrix([[1],[2]])]
+ caught = 0
+ for x in bad_vecs:
+ try:
+ y = M * x
+ except ValueError:
+ caught += 1
+ assert_equal(caught,len(bad_vecs))
+
# Should this be supported or not?!
#flat = array([1,2,3])
#assert_array_almost_equal(M*flat, M.todense()*flat)
More information about the Scipy-svn mailing list