[Scipy-svn] r3143 - in trunk/Lib/cluster: . tests

scipy-svn@scip... scipy-svn@scip...
Tue Jul 3 06:29:07 CDT 2007


Author: cdavid
Date: 2007-07-03 06:29:01 -0500 (Tue, 03 Jul 2007)
New Revision: 3143

Modified:
   trunk/Lib/cluster/tests/test_vq.py
   trunk/Lib/cluster/vq.py
Log:
Add an option to kmeans2 to decide what to do when one cluster disappears

Modified: trunk/Lib/cluster/tests/test_vq.py
===================================================================
--- trunk/Lib/cluster/tests/test_vq.py	2007-07-02 15:25:56 UTC (rev 3142)
+++ trunk/Lib/cluster/tests/test_vq.py	2007-07-03 11:29:01 UTC (rev 3143)
@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 
 # David Cournapeau
-# Last Change: Tue Jun 19 10:00 PM 2007 J
+# Last Change: Tue Jul 03 08:00 PM 2007 J
 
 # For now, just copy the tests from sandbox.pyem, so we can check that
 # kmeans works OK for trivial examples.
@@ -12,7 +12,7 @@
 import numpy as N
 
 set_package_path()
-from cluster.vq import kmeans, kmeans2, py_vq, py_vq2, _py_vq_1d, vq
+from cluster.vq import kmeans, kmeans2, py_vq, py_vq2, _py_vq_1d, vq, ClusterError
 try:
     from cluster import _vq
     TESTC=True
@@ -21,10 +21,10 @@
     TESTC=False
 restore_path()
 
+import os.path
 #Optional:
 set_local_path()
 # import modules that are located in the same directory as this file.
-import os.path
 DATAFILE1 = os.path.join(sys.path[0], "data.txt")
 restore_path()
 
@@ -106,6 +106,12 @@
                          [-2.31149087,-0.05160469]])
 
         res = kmeans(data, initk)
+        res = kmeans2(data, initk, missing = 'warn')
+        try :
+            res = kmeans2(data, initk, missing = 'raise')
+            raise AssertionError("Exception not raised ! Should not happen")
+        except ClusterError, e:
+            print "exception raised as expected: " + str(e)
 
     def check_kmeans2_simple(self, level=1):
         """Testing simple call to kmeans2 and its results."""

Modified: trunk/Lib/cluster/vq.py
===================================================================
--- trunk/Lib/cluster/vq.py	2007-07-02 15:25:56 UTC (rev 3142)
+++ trunk/Lib/cluster/vq.py	2007-07-03 11:29:01 UTC (rev 3143)
@@ -35,6 +35,9 @@
      std, mean
 import numpy as N
 
+class ClusterError(Exception):
+    pass
+
 def whiten(obs):
     """ Normalize a group of observations on a per feature basis.
 
@@ -188,7 +191,8 @@
     else:
         (n, d) = shape(obs)
 
-    # code books and observations should have same number of features and same shape
+    # code books and observations should have same number of features and same
+    # shape
     if not N.ndim(obs) == N.ndim(code_book):
         raise ValueError("Observation and code_book should have the same rank")
     elif not d == code_book.shape[1]:
@@ -228,7 +232,7 @@
     nc = code_book.size
     dist = N.zeros((n, nc))
     for i in range(nc):
-        dist[:,i] = N.sum(obs - code_book[i])
+        dist[:, i] = N.sum(obs - code_book[i])
     print dist
     code = argmin(dist)
     min_dist = dist[code]
@@ -270,7 +274,7 @@
             code book(%d) and obs(%d) should have the same
             number of features (eg columns)""" % (code_book.shape[1], d))
 
-    diff = obs[newaxis,:,:] - code_book[:,newaxis,:]
+    diff = obs[newaxis, :, :] - code_book[:,newaxis,:]
     dist = sqrt(N.sum(diff * diff, -1))
     code = argmin(dist, 0)
     min_dist = minimum.reduce(dist, 0) #the next line I think is equivalent
@@ -314,7 +318,7 @@
     """
 
     code_book = array(guess, copy = True)
-    Nc = code_book.shape[0]
+    nc = code_book.shape[0]
     avg_dist = []
     diff = thresh+1.
     while diff > thresh:
@@ -324,7 +328,7 @@
         #recalc code_book as centroids of associated obs
         if(diff > thresh):
             has_members = []
-            for i in arange(Nc):
+            for i in arange(nc):
                 cell_members = compress(equal(obs_code, i), obs, 0)
                 if cell_members.shape[0] > 0:
                     code_book[i] = mean(cell_members, 0)
@@ -468,7 +472,20 @@
 
 _valid_init_meth = {'random': _krandinit, 'points': _kpoints}
 
-def kmeans2(data, k, iter = 10, thresh = 1e-5, minit='random'):
+def _missing_warn():
+    """Print a warning when called."""
+    warnings.warn("One of the clusters is empty. "
+                 "Re-run kmean with a different initialization.")
+
+def _missing_raise():
+    """raise a ClusterError when called."""
+    raise ClusterError, "One of the clusters is empty. "\
+                        "Re-run kmean with a different initialization."
+
+_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
+
+def kmeans2(data, k, iter = 10, thresh = 1e-5, minit = 'random',
+        missing = 'warn'):
     """Classify a set of points into k clusters using kmean algorithm.
 
     The algorithm works by minimizing the euclidian distance between data points
@@ -510,6 +527,8 @@
             cluster[label[i]].
 
     """
+    if missing not in _valid_miss_meth.keys():
+        raise ValueError("Unkown missing method: %s" % str(missing))
     # If data is rank 1, then we have 1 dimension problem.
     nd  = N.ndim(data)
     if nd == 1:
@@ -544,9 +563,9 @@
         clusters = init(data, k)
 
     assert not iter == 0
-    return _kmeans2(data, clusters, iter, nc)
+    return _kmeans2(data, clusters, iter, nc, _valid_miss_meth[missing])
 
-def _kmeans2(data, code, niter, nc):
+def _kmeans2(data, code, niter, nc, missing):
     """ "raw" version of kmeans2. Do not use directly.
 
     Run kmeans with a given initial codebook.  """
@@ -560,8 +579,7 @@
             if mbs[0].size > 0:
                 code[j] = N.mean(data[mbs], axis=0)
             else:
-                warnings.warn("One of the clusters are empty. " \
-                              "Re-run kmean with a different initialization.")
+                missing()
 
     return code, label
 



More information about the Scipy-svn mailing list