diff --git a/gaussian.py b/gaussian.py index 2225c20..bbc6e96 100644 --- a/gaussian.py +++ b/gaussian.py @@ -30,14 +30,14 @@ def _to_cov(x,n): _two_pi = 2*math.pi -def gaussian (x, mean, var): +def gaussian(x, mean, var): """returns normal distribution for x given a gaussian with the specified mean and variance. All must be scalars """ return math.exp((-0.5*(x-mean)**2)/var) / math.sqrt(_two_pi*var) -def multivariate_gaussian (x, mu, cov): +def multivariate_gaussian(x, mu, cov): """ This is designed to work the same as scipy.stats.multivariate_normal which is not available before version 0.14. You may either pass in a multivariate set of data: @@ -56,20 +56,20 @@ def multivariate_gaussian (x, mu, cov): x = _to_array(x) mu = _to_array(mu) n = mu.size - cov = _to_cov (cov, n) + cov = _to_cov(cov, n) det = np.sqrt(np.prod(np.diag(cov))) frac = _two_pi**(-n/2.) * (1./det) fprime = (x - mu)**2 return frac * np.exp(-0.5*np.dot(fprime, 1./np.diag(cov))) -def norm_plot (mean, var): +def norm_plot(mean, var): min_x = mean - var * 1.5 max_x = mean + var * 1.5 - xs = np.arange (min_x, max_x, 0.1) - ys = [gaussian (x,23,5) for x in xs] - plt.plot (xs,ys) + xs = np.arange(min_x, max_x, 0.1) + ys = [gaussian(x,23,5) for x in xs] + plt.plot(xs,ys) if __name__ == '__main__': from scipy.stats import norm @@ -80,9 +80,9 @@ if __name__ == '__main__': assert x == x2 # test univarate case - rv = norm (loc = 1., scale = np.sqrt(2.3)) - x2 = multivariate_gaussian (1.2, 1., 2.3) - x3 = gaussian (1.2, 1., 2.3) + rv = norm(loc = 1., scale = np.sqrt(2.3)) + x2 = multivariate_gaussian(1.2, 1., 2.3) + x3 = gaussian(1.2, 1., 2.3) assert rv.pdf(1.2) == x2 assert abs(x2- x3) < 0.00000001