951938802c
Added text in book about multivariate Gaussian. Wrote tests to confirm that my multivariate function yields the same results as numpy's version, and made everything compliant with py.test.
89 lines
1.9 KiB
Python
89 lines
1.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
|
|
py.test module to test stats.py module.
|
|
|
|
|
|
Created on Wed Aug 27 06:45:06 2014
|
|
|
|
@author: rlabbe
|
|
"""
|
|
from __future__ import division
|
|
from math import pi, exp
|
|
import numpy as np
|
|
from stats import gaussian, multivariate_gaussian, _to_cov
|
|
from numpy.linalg import inv
|
|
from numpy import linalg
|
|
|
|
|
|
def near_equal(x, y, tol=1.e-15):
    """Return True if x and y differ by strictly less than `tol`.

    The comparison is purely absolute: ``abs(x - y) < tol``. It is only a
    meaningful check when the compared values are much larger in magnitude
    than `tol`; any two values smaller than `tol` (including 0) compare as
    equal. Pass a different `tol` when comparing quantities of another scale.

    Parameters
    ----------
    x, y : float
        Values to compare.
    tol : float, optional
        Absolute tolerance. Defaults to 1e-15.
    """
    return abs(x - y) < tol
|
|
|
|
|
|
def test_gaussian():
    """Check stats.gaussian against scipy.stats.norm on a grid of points.

    gaussian() is called with the variance, while scipy's norm is given the
    standard deviation, so the reference distribution uses sqrt(variance).
    """
    import scipy.stats

    mu, variance = 3., 1.5
    # Frozen distribution built once; pdf() is evaluated per grid point.
    reference = scipy.stats.norm(mu, variance ** 0.5)

    for point in np.arange(-5, 5, 0.1):
        expected = reference.pdf(point)
        actual = gaussian(point, mu, variance)
        assert near_equal(expected, actual)
|
|
|
|
|
|
|
|
def norm_pdf_multivariate(x, mu, sigma):
    """Evaluate the multivariate normal pdf by literal transcription.

        N(x; mu, Sigma) = exp(-1/2 (x-mu)^T Sigma^-1 (x-mu))
                          / ((2 pi)^(n/2) |Sigma|^(1/2))

    Deliberately unoptimized (explicit determinant and explicit inverse) so
    it is easy to check by eye against the textbook equation. The tests use
    it as a readable reference implementation.

    `sigma` is passed through stats._to_cov together with the dimension n;
    presumably that expands scalar/diagonal inputs to an n x n covariance --
    confirm against stats.py.
    """
    dim = len(x)
    cov = _to_cov(sigma, dim)

    cov_det = linalg.det(cov)
    scale = 1.0 / (pow((2 * pi), dim / 2) * pow(cov_det, .5))

    diff = x - mu
    # Quadratic form (x-mu)^T Sigma^-1 (x-mu), i.e. squared Mahalanobis distance.
    mahalanobis_sq = diff.dot(inv(cov)).dot(diff.T)

    return scale * exp(-0.5 * mahalanobis_sq)
|
|
|
|
|
|
|
|
def test_multivariate():
    """Check multivariate_gaussian and norm_pdf_multivariate against scipy.

    Compares with scipy.stats.multivariate_normal for a 1-D case, a fixed
    2-D point, random 2-D points near the mean, and a 4-D isotropic case.

    near_equal() uses an absolute tolerance of 1e-15, so every test point is
    chosen where the pdf is far larger than that; otherwise an assertion
    would pass for any result, including 0.
    """
    from scipy.stats import multivariate_normal as mvn

    # Seeded generator keeps the test reproducible from run to run.
    rng = np.random.RandomState(2014)

    # 1-D: scalar mean and variance.
    mean = 3
    var = 1.5
    assert near_equal(mvn(mean, var).pdf(0.5),
                      multivariate_gaussian(0.5, mean, var))

    # 2-D with a correlated covariance.
    mean = np.array([2., 17.])
    var = np.array([[10., 1.2], [1.2, 4.]])
    dist = mvn(mean, var)

    x = np.array([1, 16])
    assert near_equal(dist.pdf(x), multivariate_gaussian(x, mean, var))

    # Random points within +/-1 of the mean in each coordinate. Points must
    # stay near the mean: e.g. for x in [0, 1)^2 (about 16 units away in the
    # second coordinate) the pdf is ~1e-17, below near_equal's tolerance, and
    # the comparison could not fail.
    for _ in range(100):
        x = mean + rng.uniform(-1., 1., size=2)
        expected = dist.pdf(x)
        assert near_equal(expected, multivariate_gaussian(x, mean, var))
        assert near_equal(expected, norm_pdf_multivariate(x, mean, var))

    # 4-D isotropic covariance s*I. With x - mean = (1, 1, 1, 1) the pdf is
    # exp(-2/s) / (4 pi^2 s^2), which falls below 1e-15 for small s, so s is
    # kept in [0.5, 1.5).
    mean = np.array([1, 2, 3, 4])
    var = np.eye(4) * (0.5 + rng.rand())

    x = np.array([2, 3, 4, 5])
    assert near_equal(mvn(mean, var).pdf(x),
                      norm_pdf_multivariate(x, mean, var))
|
|
|
|
|