PEP8 compliance

This commit is contained in:
Roger Labbe 2014-05-01 09:27:55 -05:00
parent 82717d4830
commit 6965522bf5

View File

@ -30,14 +30,14 @@ def _to_cov(x,n):
_two_pi = 2*math.pi
def gaussian (x, mean, var):
def gaussian(x, mean, var):
"""returns normal distribution for x given a gaussian with the specified
mean and variance. All must be scalars
"""
return math.exp((-0.5*(x-mean)**2)/var) / math.sqrt(_two_pi*var)
def multivariate_gaussian (x, mu, cov):
def multivariate_gaussian(x, mu, cov):
""" This is designed to work the same as scipy.stats.multivariate_normal
which is not available before version 0.14. You may either pass in a
multivariate set of data:
@ -56,20 +56,20 @@ def multivariate_gaussian (x, mu, cov):
x = _to_array(x)
mu = _to_array(mu)
n = mu.size
cov = _to_cov (cov, n)
cov = _to_cov(cov, n)
det = np.sqrt(np.prod(np.diag(cov)))
frac = _two_pi**(-n/2.) * (1./det)
fprime = (x - mu)**2
return frac * np.exp(-0.5*np.dot(fprime, 1./np.diag(cov)))
def norm_plot (mean, var):
def norm_plot(mean, var):
min_x = mean - var * 1.5
max_x = mean + var * 1.5
xs = np.arange (min_x, max_x, 0.1)
ys = [gaussian (x,23,5) for x in xs]
plt.plot (xs,ys)
xs = np.arange(min_x, max_x, 0.1)
ys = [gaussian(x,23,5) for x in xs]
plt.plot(xs,ys)
if __name__ == '__main__':
from scipy.stats import norm
@ -80,9 +80,9 @@ if __name__ == '__main__':
assert x == x2
# test univarate case
rv = norm (loc = 1., scale = np.sqrt(2.3))
x2 = multivariate_gaussian (1.2, 1., 2.3)
x3 = gaussian (1.2, 1., 2.3)
rv = norm(loc = 1., scale = np.sqrt(2.3))
x2 = multivariate_gaussian(1.2, 1., 2.3)
x3 = gaussian(1.2, 1., 2.3)
assert rv.pdf(1.2) == x2
assert abs(x2- x3) < 0.00000001