# [Scipy-svn] r2230 - trunk/Lib/special

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Sep 24 03:24:25 CDT 2006

```
Author: rkern
Date: 2006-09-24 03:24:24 -0500 (Sun, 24 Sep 2006)
New Revision: 2230

Modified:
trunk/Lib/special/orthogonal.py
Log:
Use modern numpy idioms.

Modified: trunk/Lib/special/orthogonal.py
===================================================================
--- trunk/Lib/special/orthogonal.py	2006-09-24 08:08:41 UTC (rev 2229)
+++ trunk/Lib/special/orthogonal.py	2006-09-24 08:24:24 UTC (rev 2230)
@@ -1,5 +1,3 @@
-## Automatically adapted for scipy Oct 05, 2005 by convertcode.py
-
#!/usr/bin/env python
#
# Author:  Travis Oliphant 2000
@@ -60,21 +58,25 @@
"""

from __future__ import nested_scopes
-from numpy import *
-from numpy.oldnumeric import take
+
+# Scipy imports.
+import numpy as np
+from numpy import all, any, exp, inf, pi, sqrt
+from scipy.linalg import eig
+
+# Local imports.
import _cephes as cephes
_gam = cephes.gamma
-from scipy.linalg import eig

def poch(z,m):
"""Pochhammer symbol (z)_m = (z)(z+1)....(z+m-1) = gamma(z+m)/gamma(z)"""
return _gam(z+m) / _gam(z)

-class orthopoly1d(poly1d):
+class orthopoly1d(np.poly1d):
def __init__(self, roots, weights=None, hn=1.0, kn=1.0, wfunc=None, limits=None, monic=0):
-        poly1d.__init__(self, roots, r=1)
+        np.poly1d.__init__(self, roots, r=1)
equiv_weights = [weights[k] / wfunc(roots[k]) for k in range(len(roots))]
-        self.__dict__['weights'] = array(zip(roots,weights,equiv_weights))
+        self.__dict__['weights'] = np.array(zip(roots,weights,equiv_weights))
self.__dict__['weight_func'] = wfunc
self.__dict__['limits'] = limits
mu = sqrt(hn)
@@ -88,7 +90,7 @@
def gen_roots_and_weights(n,an_func,sqrt_bn_func,mu):
"""[x,w] = gen_roots_and_weights(n,an_func,sqrt_bn_func,mu)

-    Returns the roots (x) of an nth order orthogonal polynomail,
+    Returns the roots (x) of an nth order orthogonal polynomial,
and weights (w) to use in appropriate Gaussian quadrature with that
orthogonal polynomial.

@@ -99,14 +101,16 @@
sqrt_bn_func(n)     should return sqrt(B_n)
mu ( = h_0 )        is the integral of the weight over the orthogonal interval
"""
-    nn = arange(1.0,n)
+    nn = np.arange(1.0,n)
sqrt_bn = sqrt_bn_func(nn)
-    an = an_func(concatenate(([0],nn)))
-    [x,v] = eig((diag(an)+diag(sqrt_bn,1)+diag(sqrt_bn,-1)))
+    an = an_func(np.concatenate(([0], nn)))
+    x, v = eig((np.diagflat(an) +
+                np.diagflat(sqrt_bn,1) +
+                np.diagflat(sqrt_bn,-1)))
answer = []
-    sortind = argsort(real(x))
-    answer.append(take(x,sortind,axis=0))
-    answer.append(take(mu*v[0]**2,sortind,axis=0))
+    sortind = x.real.argsort()
+    answer.append(x[sortind])
+    answer.append((mu*v[0]**2)[sortind])
return answer

# Jacobi Polynomials 1               P^(alpha,beta)_n(x)
@@ -118,17 +122,17 @@
function (1-x)**alpha (1+x)**beta with alpha,beta > -1.
"""
if any(alpha <= -1) or any(beta <= -1):
-        raise ValueError, "alpha and beta must be greater than -1."
+        raise ValueError("alpha and beta must be greater than -1.")
assert(n>0), "n must be positive."

(p,q) = (alpha,beta)
# from recurrence relations
sbn_J = lambda k: 2.0/(2.0*k+p+q)*sqrt((k+p)*(k+q)/(2*k+q+p+1)) * \
-                (where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
+                (np.where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
if any(p == q):  # XXX any or all???
an_J = lambda k: 0.0*k
else:
-        an_J = lambda k: where(k==0,(q-p)/(p+q+2.0),
+        an_J = lambda k: np.where(k==0,(q-p)/(p+q+2.0),
(q*q - p*p)/((2.0*k+p+q)*(2.0*k+p+q+2)))
g = cephes.gamma
mu0 = 2.0**(p+q+1)*g(p+1)*g(q+1)/(g(p+q+2))
@@ -145,7 +149,8 @@
"""
assert(n>=0), "n must be nonnegative"
wfunc = lambda x: (1-x)**alpha * (1+x)**beta
-    if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
x,w,mu = j_roots(n,alpha,beta,mu=1)
ab1 = alpha+beta+1.0
hn = 2**ab1/(2*n+ab1)*_gam(n+alpha+1)
@@ -164,17 +169,17 @@
function (1-x)**(p-q) x**(q-1) with p-q > -1 and q > 0.
"""
# from recurrence relation
-    if not ( any( (p1 - q1) > -1 ) and any( q1 > 0 ) ):
-        raise ValueError, "(p - q) > -1 and q > 0 please."
+    if not ( any((p1 - q1) > -1) and any(q1 > 0) ):
+        raise ValueError("(p - q) > -1 and q > 0 please.")
if (n <= 0):
-        raise ValueError, "n must be positive."
+        raise ValueError("n must be positive.")

p,q = p1,q1

-    sbn_Js = lambda k: sqrt(where(k==1,q*(p-q+1.0)/(p+2.0), \
+    sbn_Js = lambda k: sqrt(np.where(k==1,q*(p-q+1.0)/(p+2.0), \
k*(k+q-1.0)*(k+p-1.0)*(k+p-q) \
/ ((2.0*k+p-2) * (2.0*k+p))))/(2*k+p-1.0)
-    an_Js = lambda k: where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))
+    an_Js = lambda k: np.where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))

# could also use definition
#  Gn(p,q,x) = constant_n * P^(p-q,q-1)_n(2x-1)
@@ -201,9 +206,10 @@
(1-x)**(p-q) (x)**(q-1) with p>q-1 and q > 0.
"""
if (n<0):
-        raise ValueError, "n must be nonnegative"
+        raise ValueError("n must be nonnegative")
wfunc = lambda x: (1.0-x)**(p-q) * (x)**(q-1.)
-    if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
n1 = n
x,w,mu0 = js_roots(n1,p,q,mu=1)
hn = _gam(n+1)*_gam(n+q)*_gam(n+p)*_gam(n+p-q+1)
@@ -222,7 +228,7 @@
[0,inf] with weighting function exp(-x) x**alpha with alpha > -1.
"""
if not all(alpha > -1):
-        raise ValueError, "alpha > -1"
+        raise ValueError("alpha > -1")
assert(n>0), "n must be positive."
(p,q) = (alpha,0.0)
sbn_La = lambda k: -sqrt(k*(k + p))  # from recurrence relation
@@ -240,7 +246,7 @@
exp(-x) x**alpha with alpha > -1
"""
if any(alpha <= -1):
-        raise ValueError, "alpha must be > -1"
+        raise ValueError("alpha must be > -1")
assert(n>=0), "n must be nonnegative"
if n==0: n1 = n+1
else: n1 = n
@@ -378,7 +384,7 @@
"""
assert(n>0), "n must be positive."
# from recurrence relation
-    sbn_J = lambda k: where(k==1,sqrt(2)/2.0,0.5)
+    sbn_J = lambda k: np.where(k==1,sqrt(2)/2.0,0.5)
an_J = lambda k: 0.0*k
g = cephes.gamma
mu0 = pi
@@ -394,7 +400,8 @@
"""
assert(n>=0), "n must be nonnegative"
wfunc = lambda x: 1.0/sqrt(1-x*x)
-    if n==0: return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
n1 = n
x,w,mu = t_roots(n1,mu=1)
hn = pi/2

```

More information about the Scipy-svn mailing list