renamed to be more readable
This commit is contained in:
88
experiements/test_stats.py
Normal file
88
experiements/test_stats.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
py.test module to test stats.py module.
|
||||
|
||||
|
||||
Created on Wed Aug 27 06:45:06 2014
|
||||
|
||||
@author: rlabbe
|
||||
"""
|
||||
from __future__ import division
|
||||
from math import pi, exp
|
||||
import numpy as np
|
||||
from stats import gaussian, multivariate_gaussian, _to_cov
|
||||
from numpy.linalg import inv
|
||||
from numpy import linalg
|
||||
|
||||
|
||||
def near_equal(x,y):
|
||||
return abs(x-y) < 1.e-15
|
||||
|
||||
|
||||
def test_gaussian():
|
||||
import scipy.stats
|
||||
|
||||
mean = 3.
|
||||
var = 1.5
|
||||
std = var**0.5
|
||||
|
||||
for i in np.arange(-5,5,0.1):
|
||||
p0 = scipy.stats.norm(mean, std).pdf(i)
|
||||
p1 = gaussian(i, mean, var)
|
||||
|
||||
assert near_equal(p0, p1)
|
||||
|
||||
|
||||
|
||||
def norm_pdf_multivariate(x, mu, sigma):
|
||||
""" extremely literal transcription of the multivariate equation.
|
||||
Slow, but easy to verify by eye compared to my version."""
|
||||
|
||||
n = len(x)
|
||||
sigma = _to_cov(sigma,n)
|
||||
|
||||
det = linalg.det(sigma)
|
||||
|
||||
norm_const = 1.0 / (pow((2*pi), n/2) * pow(det, .5))
|
||||
x_mu = x - mu
|
||||
result = exp(-0.5 * (x_mu.dot(inv(sigma)).dot(x_mu.T)))
|
||||
return norm_const * result
|
||||
|
||||
|
||||
|
||||
def test_multivariate():
|
||||
from scipy.stats import multivariate_normal as mvn
|
||||
from numpy.random import rand
|
||||
|
||||
mean = 3
|
||||
var = 1.5
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(0.5),
|
||||
multivariate_gaussian(0.5, mean, var))
|
||||
|
||||
mean = np.array([2.,17.])
|
||||
var = np.array([[10., 1.2], [1.2, 4.]])
|
||||
|
||||
x = np.array([1,16])
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
multivariate_gaussian(x, mean, var))
|
||||
|
||||
for i in range(100):
|
||||
x = np.array([rand(), rand()])
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
multivariate_gaussian(x, mean, var))
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
norm_pdf_multivariate(x, mean, var))
|
||||
|
||||
|
||||
mean = np.array([1,2,3,4])
|
||||
var = np.eye(4)*rand()
|
||||
|
||||
x = np.array([2,3,4,5])
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
norm_pdf_multivariate(x, mean, var))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user