[Numpy-svn] r5694 - trunk/numpy/random/mtrand

numpy-svn@scip... numpy-svn@scip...
Sun Aug 24 18:22:13 CDT 2008


Author: rkern
Date: 2008-08-24 18:22:11 -0500 (Sun, 24 Aug 2008)
New Revision: 5694

Modified:
   trunk/numpy/random/mtrand/distributions.c
   trunk/numpy/random/mtrand/mtrand.c
   trunk/numpy/random/mtrand/mtrand.pyx
Log:
BUG: The logarithmic series distribution needs to exclude p==0 and p==1. When converting the result to a C long gives an out-of-bounds value (less than 1, e.g. a negative number), reject the sample and try again until we get something in bounds.

Modified: trunk/numpy/random/mtrand/distributions.c
===================================================================
--- trunk/numpy/random/mtrand/distributions.c	2008-08-24 21:02:55 UTC (rev 5693)
+++ trunk/numpy/random/mtrand/distributions.c	2008-08-24 23:22:11 UTC (rev 5694)
@@ -848,14 +848,29 @@
 long rk_logseries(rk_state *state, double p)
 {
     double q, r, U, V;
+    long result;
     
     r = log(1.0 - p);
-    
-    V = rk_double(state);
-    if (V >= p) return 1;
-    U = rk_double(state);
-    q = 1.0 - exp(r*U);
-    if (V <= q*q) return (long)floor(1 + log(V)/log(q));
-    if (V <= q) return 1;
-    return 2;
+
+    while (1) {
+        V = rk_double(state);
+        if (V >= p) {
+            return 1;
+        }
+        U = rk_double(state);
+        q = 1.0 - exp(r*U);
+        if (V <= q*q) {
+            result = (long)floor(1 + log(V)/log(q));
+            if (result < 1) {
+                continue;
+            }
+            else {
+                return result;
+            }
+        }
+        if (V <= q) {
+            return 1;
+        }
+        return 2;
+    }
 }

Modified: trunk/numpy/random/mtrand/mtrand.c
===================================================================
--- trunk/numpy/random/mtrand/mtrand.c	2008-08-24 21:02:55 UTC (rev 5693)
+++ trunk/numpy/random/mtrand/mtrand.c	2008-08-24 23:22:11 UTC (rev 5694)
@@ -1,4 +1,4 @@
-/* Generated by Pyrex 0.9.6.4 on Fri Aug 22 22:54:35 2008 */
+/* Generated by Pyrex 0.9.6.4 on Sun Aug 24 16:14:30 2008 */
 
 #define PY_SSIZE_T_CLEAN
 #include "Python.h"
@@ -8171,15 +8171,17 @@
   return __pyx_r;
 }
 
+static PyObject *__pyx_n_greater_equal;
+
 static PyObject *__pyx_k162p;
 static PyObject *__pyx_k163p;
 static PyObject *__pyx_k164p;
 static PyObject *__pyx_k165p;
 
-static char __pyx_k162[] = "p < 0.0";
-static char __pyx_k163[] = "p > 1.0";
-static char __pyx_k164[] = "p < 0.0";
-static char __pyx_k165[] = "p > 1.0";
+static char __pyx_k162[] = "p <= 0.0";
+static char __pyx_k163[] = "p >= 1.0";
+static char __pyx_k164[] = "p <= 0.0";
+static char __pyx_k165[] = "p >= 1.0";
 
 static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
 static char __pyx_doc_6mtrand_11RandomState_logseries[] = "\n        logseries(p, size=None)\n\n        Logarithmic series distribution.\n\n        ";
@@ -8210,7 +8212,7 @@
   if (__pyx_1) {
 
     /* "/Users/rkern/svn/numpy/numpy/random/mtrand/mtrand.pyx":2435 */
-    __pyx_1 = (__pyx_v_fp < 0.0);
+    __pyx_1 = (__pyx_v_fp <= 0.0);
     if (__pyx_1) {
       __pyx_2 = __Pyx_GetName(__pyx_b, __pyx_n_ValueError); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2436; goto __pyx_L1;}
       __pyx_3 = PyTuple_New(1); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2436; goto __pyx_L1;}
@@ -8227,7 +8229,7 @@
     __pyx_L3:;
 
     /* "/Users/rkern/svn/numpy/numpy/random/mtrand/mtrand.pyx":2437 */
-    __pyx_1 = (__pyx_v_fp > 1.0);
+    __pyx_1 = (__pyx_v_fp >= 1.0);
     if (__pyx_1) {
       __pyx_2 = __Pyx_GetName(__pyx_b, __pyx_n_ValueError); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2438; goto __pyx_L1;}
       __pyx_3 = PyTuple_New(1); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2438; goto __pyx_L1;}
@@ -8267,7 +8269,7 @@
   __pyx_2 = PyObject_GetAttr(__pyx_4, __pyx_n_any); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
   Py_DECREF(__pyx_4); __pyx_4 = 0;
   __pyx_3 = __Pyx_GetName(__pyx_m, __pyx_n_np); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
-  __pyx_4 = PyObject_GetAttr(__pyx_3, __pyx_n_less); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
+  __pyx_4 = PyObject_GetAttr(__pyx_3, __pyx_n_less_equal); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
   Py_DECREF(__pyx_3); __pyx_3 = 0;
   __pyx_3 = PyFloat_FromDouble(0.0); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
   __pyx_5 = PyTuple_New(2); if (!__pyx_5) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;}
@@ -8306,7 +8308,7 @@
   __pyx_3 = PyObject_GetAttr(__pyx_5, __pyx_n_any); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
   Py_DECREF(__pyx_5); __pyx_5 = 0;
   __pyx_2 = __Pyx_GetName(__pyx_m, __pyx_n_np); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
-  __pyx_4 = PyObject_GetAttr(__pyx_2, __pyx_n_greater); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
+  __pyx_4 = PyObject_GetAttr(__pyx_2, __pyx_n_greater_equal); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
   Py_DECREF(__pyx_2); __pyx_2 = 0;
   __pyx_5 = PyFloat_FromDouble(1.0); if (!__pyx_5) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
   __pyx_2 = PyTuple_New(2); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;}
@@ -9461,6 +9463,7 @@
   {&__pyx_n_geometric, "geometric"},
   {&__pyx_n_get_state, "get_state"},
   {&__pyx_n_greater, "greater"},
+  {&__pyx_n_greater_equal, "greater_equal"},
   {&__pyx_n_gumbel, "gumbel"},
   {&__pyx_n_hypergeometric, "hypergeometric"},
   {&__pyx_n_int, "int"},

Modified: trunk/numpy/random/mtrand/mtrand.pyx
===================================================================
--- trunk/numpy/random/mtrand/mtrand.pyx	2008-08-24 21:02:55 UTC (rev 5693)
+++ trunk/numpy/random/mtrand/mtrand.pyx	2008-08-24 23:22:11 UTC (rev 5694)
@@ -2432,19 +2432,19 @@
 
         fp = PyFloat_AsDouble(p)
         if not PyErr_Occurred():
-            if fp < 0.0:
-                raise ValueError("p < 0.0")
-            if fp > 1.0:
-                raise ValueError("p > 1.0")
+            if fp <= 0.0:
+                raise ValueError("p <= 0.0")
+            if fp >= 1.0:
+                raise ValueError("p >= 1.0")
             return discd_array_sc(self.internal_state, rk_logseries, size, fp)
 
         PyErr_Clear()
 
         op = <ndarray>PyArray_FROM_OTF(p, NPY_DOUBLE, NPY_ALIGNED)
-        if np.any(np.less(op, 0.0)):
-            raise ValueError("p < 0.0")
-        if np.any(np.greater(op, 1.0)):
-            raise ValueError("p > 1.0")
+        if np.any(np.less_equal(op, 0.0)):
+            raise ValueError("p <= 0.0")
+        if np.any(np.greater_equal(op, 1.0)):
+            raise ValueError("p >= 1.0")
         return discd_array(self.internal_state, rk_logseries, size, op)
 
     # Multivariate distributions:



More information about the Numpy-svn mailing list