[Scipy-svn] r3127 - in trunk/Lib/sandbox/pyem: . tests

scipy-svn@scip... scipy-svn@scip...
Sun Jul 1 04:52:13 CDT 2007


Author: cdavid
Date: 2007-07-01 04:52:06 -0500 (Sun, 01 Jul 2007)
New Revision: 3127

Modified:
   trunk/Lib/sandbox/pyem/gmm_em.py
   trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
Log:
Add support for EM in log domain + tests

Modified: trunk/Lib/sandbox/pyem/gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/gmm_em.py	2007-07-01 09:32:00 UTC (rev 3126)
+++ trunk/Lib/sandbox/pyem/gmm_em.py	2007-07-01 09:52:06 UTC (rev 3127)
@@ -1,5 +1,5 @@
 # /usr/bin/python
-# Last Change: Sun Jul 01 05:00 PM 2007 J
+# Last Change: Sun Jul 01 06:00 PM 2007 J
 
 """Module implementing GMM, a class to estimate Gaussian mixture models using
 EM, and EM, a class which use GMM instances to estimate models parameters using
@@ -331,7 +331,7 @@
     def __init__(self):
         pass
     
-    def train(self, data, model, maxiter = 10, thresh = 1e-5):
+    def train(self, data, model, maxiter = 10, thresh = 1e-5, log = False):
         """Train a model using EM.
 
         Train a model using data, and stops when the likelihood increase
@@ -366,7 +366,10 @@
         model.init(data)
 
         # Actual training
-        like = self._train_simple_em(data, model, maxiter, thresh)
+        if log:
+            like = self._train_simple_em_log(data, model, maxiter, thresh)
+        else:
+            like = self._train_simple_em(data, model, maxiter, thresh)
         return like
     
     def _train_simple_em(self, data, model, maxiter, thresh):
@@ -385,6 +388,21 @@
             if has_em_converged(like[i], like[i-1], thresh):
                 return like[0:i]
 
+    def _train_simple_em_log(self, data, model, maxiter, thresh):
+        # Likelihood is kept
+        like    = N.zeros(maxiter)
+
+        # EM computation, with computation of the likelihood
+        g, tgd  = model.compute_log_responsabilities(data)
+        like[0] = N.sum(densities.logsumexp(tgd), axis = 0)
+        model.update_em(data, N.exp(g))
+        for i in range(1, maxiter):
+            g, tgd  = model.compute_log_responsabilities(data)
+            like[i] = N.sum(densities.logsumexp(tgd), axis = 0)
+            model.update_em(data, N.exp(g))
+            if has_em_converged(like[i], like[i-1], thresh):
+                return like[0:i]
+
 class RegularizedEM:
     # TODO: separate regularizer from EM class ?
     def __init__(self, pcnt = _PRIOR_COUNT, pval = _COV_PRIOR):

Modified: trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_gmm_em.py	2007-07-01 09:32:00 UTC (rev 3126)
+++ trunk/Lib/sandbox/pyem/tests/test_gmm_em.py	2007-07-01 09:52:06 UTC (rev 3127)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-# Last Change: Wed Jun 13 07:00 PM 2007 J
+# Last Change: Sun Jul 01 06:00 PM 2007 J
 
 # For now, just test that all mode/dim execute correctly
 
@@ -110,65 +110,55 @@
 class test_datasets(EmTest):
     """This class tests whether the EM algorithms works using pre-computed
     datasets."""
-    def test_1d_full(self, level = 1):
-        d = 1
-        k = 4
-        mode = 'full'
-        # Data are exactly the same than in diagonal mode, just test that
-        # calling full mode works even in 1d, even if it is kind of stupid to
-        # do so
-        dic = load_dataset('diag_1d_4k.mat')
+    def _test(self, dataset, log):
+        dic = load_dataset(dataset)
 
         gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
         gmm = GMM(gm, 'test')
-        EM().train(dic['data'], gmm)
+        EM().train(dic['data'], gmm, log = log)
 
         assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
         assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
         assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
 
-    def test_1d_diag(self, level = 1):
+    def test_1d_full(self, level = 1):
         d = 1
         k = 4
-        mode = 'diag'
-        dic = load_dataset('diag_1d_4k.mat')
+        mode = 'full'
+        # Data are exactly the same as in diagonal mode, just test that
+        # calling full mode works even in 1d, even if it is kind of stupid to
+        # do so
+        filename = 'diag_1d_4k.mat'
+        self._test(filename, log = False)
 
-        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
-        gmm = GMM(gm, 'test')
-        EM().train(dic['data'], gmm)
-
-        assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
     def test_2d_full(self, level = 1):
         d = 2
         k = 3
         mode = 'full'
-        dic = load_dataset('full_2d_3k.mat')
+        filename = 'full_2d_3k.mat'
+        self._test(filename, log = False)
 
-        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
-        gmm = GMM(gm, 'test')
-        EM().train(dic['data'], gmm)
+    def test_2d_full_log(self, level = 1):
+        d = 2
+        k = 3
+        mode = 'full'
+        filename = 'full_2d_3k.mat'
+        self._test(filename, log = True)
 
-        assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
     def test_2d_diag(self, level = 1):
         d = 2
         k = 3
         mode = 'diag'
-        dic = load_dataset('diag_2d_3k.mat')
+        filename = 'diag_2d_3k.mat'
+        self._test(filename, log = False)
 
-        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
-        gmm = GMM(gm, 'test')
-        EM().train(dic['data'], gmm)
+    def test_2d_diag_log(self, level = 1):
+        d = 2
+        k = 3
+        mode = 'diag'
+        filename = 'diag_2d_3k.mat'
+        self._test(filename, log = True)
 
-        assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
-        assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
 class test_log_domain(EmTest):
     """This class tests whether the GMM works in log domain."""
     def _test_common(self, d, k, mode):



More information about the Scipy-svn mailing list