[Scipy-svn] r3301 - in trunk/scipy/sparse: . tests

scipy-svn@scip... scipy-svn@scip...
Tue Sep 4 17:08:25 CDT 2007


Author: wnbell
Date: 2007-09-04 17:08:22 -0500 (Tue, 04 Sep 2007)
New Revision: 3301

Modified:
   trunk/scipy/sparse/sparse.py
   trunk/scipy/sparse/tests/test_sparse.py
Log:
Added support for in-place scalar multiplication and division.



Modified: trunk/scipy/sparse/sparse.py
===================================================================
--- trunk/scipy/sparse/sparse.py	2007-09-02 20:21:59 UTC (rev 3300)
+++ trunk/scipy/sparse/sparse.py	2007-09-04 22:08:22 UTC (rev 3301)
@@ -271,10 +271,13 @@
 
     def __imul__(self, other):
         raise NotImplementedError
-
+    
     def __idiv__(self, other):
-        raise TypeError("No support for matrix division.")
+        return self.__itruediv__(other)
 
+    def __itruediv__(self, other):
+        raise NotImplementedError
+
     def __getattr__(self, attr):
         if attr == 'A':
             return self.toarray()
@@ -585,7 +588,7 @@
             # Convert this matrix to a dense matrix and add them
             return self.todense() - other
         else:
-            raise NotImplemented
+            raise NotImplementedError
 
 
     def __mul__(self, other): # self * other
@@ -608,6 +611,12 @@
                 tr = asarray(other).transpose()
             return self.transpose().dot(tr).transpose()
 
+    def __imul__(self, other): #self *= other
+        if isscalarlike(other):
+            self.data *= other
+            return self
+        else:
+            raise NotImplementedError
 
     def __neg__(self):
         return self._with_data(-self.data)
@@ -621,9 +630,16 @@
                 raise ValueError, "inconsistent shapes"
             return self._binopt(other,fn)
         else:
-            raise NotImplemented
+            raise NotImplementedError
+    
+    def __itruediv__(self, other): #self *= other
+        if isscalarlike(other):
+            recip = 1.0 / other
+            self.data *= recip
+            return self
+        else:
+            raise NotImplementedError
 
-
     def __pow__(self, other, fn):
         """ Element-by-element power (unless other is a scalar, in which
         case return the matrix power.)
@@ -633,7 +649,7 @@
         elif isspmatrix(other):
             return self._binopt(other,fn)
         else:
-            raise NotImplemented
+            raise NotImplementedError
 
 
     def _matmat(self, other, fn):
@@ -1826,6 +1842,16 @@
         else:
             return self.dot(other)
 
+    def __imul__(self, other):           # self * other
+        if isscalarlike(other):
+            # Multiply this scalar by every element.
+            for (key, val) in self.iteritems():
+                self[key] = val * other
+            #new.dtype.char = self.dtype.char
+            return self
+        else:
+            return NotImplementedError
+
     def __rmul__(self, other):          # other * self
         if isscalarlike(other):
             new = dok_matrix(self.shape, dtype=self.dtype)
@@ -1841,7 +1867,28 @@
             except AttributeError:
                 tr = asarray(other).transpose()
             return self.transpose().dot(tr).transpose()
+    
+    def __truediv__(self, other):           # self * other
+        if isscalarlike(other):
+            new = dok_matrix(self.shape, dtype=self.dtype)
+            # Multiply this scalar by every element.
+            for (key, val) in self.iteritems():
+                new[key] = val / other
+            #new.dtype.char = self.dtype.char
+            return new
+        else:
+            return self.tocsr() / other
 
+    
+    def __itruediv__(self, other):           # self * other
+        if isscalarlike(other):
+            # Multiply this scalar by every element.
+            for (key, val) in self.iteritems():
+                self[key] = val / other
+            return self
+        else:
+            return NotImplementedError
+
     # What should len(sparse) return? For consistency with dense matrices,
     # perhaps it should be the number of rows?  For now it returns the number
     # of non-zeros.
@@ -2259,8 +2306,15 @@
             self[:,:] = self * other
             return self
         else:
-            raise TypeError("In-place matrix multiplication not supported.")
+            raise NotImplementedError
 
+    def __itruediv__(self,other):
+        if isscalarlike(other):
+            self[:,:] = self / other
+            return self
+        else:
+            raise NotImplementedError
+
     # Whenever the dimensions change, empty lists should be created for each
     # row
 
@@ -2463,6 +2517,16 @@
         else:
             return self.dot(other)
 
+    def __truediv__(self, other):           # self / other
+        if isscalarlike(other):
+            new = self.copy()
+            # Divide every element by this scalar
+            new.data = numpy.array([[val/other for val in rowvals] for
+                                    rowvals in new.data], dtype=object)
+            return new
+        else:
+            return self.tocsr() / other
+
     def multiply(self, other):
         """Point-wise multiplication by another lil_matrix.
 

Modified: trunk/scipy/sparse/tests/test_sparse.py
===================================================================
--- trunk/scipy/sparse/tests/test_sparse.py	2007-09-02 20:21:59 UTC (rev 3300)
+++ trunk/scipy/sparse/tests/test_sparse.py	2007-09-04 22:08:22 UTC (rev 3301)
@@ -99,6 +99,24 @@
         assert_array_equal(self.dat*2,(self.datsp*2).todense())
         assert_array_equal(self.dat*17.3,(self.datsp*17.3).todense())
 
+    def check_imul_scalar(self):
+        a = self.datsp.copy()
+        a *= 2
+        assert_array_equal(self.dat*2,a.todense())
+
+        a = self.datsp.copy()
+        a *= 17.3
+        assert_array_equal(self.dat*17.3,a.todense())
+
+    def check_idiv_scalar(self):
+        a = self.datsp.copy()
+        a /= 2
+        assert_array_equal(self.dat/2,a.todense())
+
+        a = self.datsp.copy()
+        a /= 17.3
+        assert_array_equal(self.dat/17.3,a.todense())
+
     def check_rmul_scalar(self):
         assert_array_equal(2*self.dat,(2*self.datsp).todense())
         assert_array_equal(17.3*self.dat,(17.3*self.datsp).todense())



More information about the Scipy-svn mailing list