Text and tests for multivariate Gaussian.

Added text to the book about the multivariate Gaussian. Wrote tests to confirm
that my multivariate_gaussian function yields the same results as scipy's
version, and made everything compliant with py.test.
Roger Labbe 2014-08-27 07:33:45 -07:00
parent aeedd289f8
commit 951938802c
5 changed files with 236 additions and 99 deletions

File diff suppressed because one or more lines are too long
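
For reference, the density that the new test_stats.py below checks (both the book's multivariate_gaussian and scipy.stats.multivariate_normal implement it) is the standard multivariate normal pdf:

\mathcal{N}(x;\,\mu,\Sigma) = \frac{1}{\sqrt{(2\pi)^n\,|\Sigma|}}\exp\left(-\frac{1}{2}(x-\mu)^\mathsf{T}\Sigma^{-1}(x-\mu)\right)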


@@ -6,7 +6,7 @@ Created on Wed Jun 4 12:33:38 2014
"""
from __future__ import print_function,division
from KalmanFilter import KalmanFilter
from filterpy.kalman import KalmanFilter
import numpy as np
import matplotlib.pyplot as plt
import baseball
@@ -31,19 +31,19 @@ def ball_filter6(dt,R=1., Q = 0.1):
[0,0,0,0,0,0]])
f1.R = np.mat(np.eye(6)) * R
f1.R = np.mat(np.eye(6)) * R
f1.Q = np.zeros((6,6))
f1.Q[2,2] = Q
f1.Q[2,2] = Q
f1.Q[5,5] = Q
f1.x = np.mat([0, 0 , 0, 0, 0, 0]).T
f1.P = np.eye(6) * 50.
f1.B = 0.
f1.u = 0
return f1
def ball_filter4(dt,R=1., Q = 0.1):
f1 = KalmanFilter(dim=4)
g = 10
@@ -57,44 +57,44 @@ def ball_filter4(dt,R=1., Q = 0.1):
[0,0,0,0],
[0,0,1,0],
[0,0,0,0]])
f1.B = np.mat([[0,0,0,0],
[0,0,0,0],
[0,0,1.,0],
[0,0,0,1.]])
f1.u = np.mat([[0],
[0],
[-0.5*g*dt**2],
[-g*dt]])
f1.R = np.mat(np.eye(4)) * R
f1.R = np.mat(np.eye(4)) * R
f1.Q = np.zeros((4,4))
f1.Q[1,1] = Q
f1.Q[1,1] = Q
f1.Q[3,3] = Q
f1.x = np.mat([0, 0 , 0, 0]).T
f1.P = np.eye(4) * 50.
f1.P = np.eye(4) * 50.
return f1
def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
xs, ys = [],[]
pxs, pys = [],[]
for i,z in enumerate(zs):
m = np.mat([z[0], 0, 0, z[1], 0, 0]).T
f1.predict ()
if i == skip_start:
x2 = xs[-2]
x1 = xs[-1]
y2 = ys[-2]
y1 = ys[-1]
if i >= skip_start and i <= skip_end:
#x,y = baseball.predict (x2, y2, x1, y1, 1/30, 1/30* (i-skip_start+1))
x,y = baseball.predict (xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
@@ -105,7 +105,7 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
f1.update(m)
'''
if i >= skip_start and i <= skip_end:
#f1.x[2] = -0.1
@@ -114,13 +114,13 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
else:
f1.update (m)
'''
'''
xs.append (f1.x[0,0])
ys.append (f1.x[3,0])
pxs.append (z[0])
pys.append(z[1])
if i > 0 and z[1] < 0:
break;
@@ -130,34 +130,34 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
plt.legend([p1,p2], ['filter', 'measurement'], 2)
plt.xlim([0,xs[-1]])
plt.show()
def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
xs, ys = [],[]
pxs, pys = [],[]
for i,z in enumerate(zs):
m = np.mat([z[0], 0, z[1], 0]).T
f1.predict ()
if i == skip_start:
x2 = xs[-2]
x1 = xs[-1]
y2 = ys[-2]
y1 = ys[-1]
if i >= skip_start and i <= skip_end:
#x,y = baseball.predict (x2, y2, x1, y1, 1/30, 1/30* (i-skip_start+1))
x,y = baseball.predict (xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
m[0] = x
m[2] = y
f1.update (m)
'''
if i >= skip_start and i <= skip_end:
#f1.x[2] = -0.1
@@ -166,13 +166,13 @@ def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
else:
f1.update (m)
'''
'''
xs.append (f1.x[0,0])
ys.append (f1.x[2,0])
pxs.append (z[0])
pys.append(z[1])
if i > 0 and z[1] < 0:
break;
@@ -187,32 +187,33 @@ def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
start_skip = 20
end_skip = 60
def run_6():
def run_6():
dt = 1/30
noise = 0.0
f1 = ball_filter6(dt, R=.16, Q=0.1)
plt.cla()
x,y = baseball.compute_trajectory(v_0_mph = 100., theta=50., dt=dt)
znoise = [(i+randn()*noise,j+randn()*noise) for (i,j) in zip(x,y)]
plot_ball_filter6 (f1, znoise, start_skip, end_skip)
def run_4():
def run_4():
dt = 1/30
noise = 0.0
f1 = ball_filter4(dt, R=.16, Q=0.1)
plt.cla()
x,y = baseball.compute_trajectory(v_0_mph = 100., theta=50., dt=dt)
znoise = [(i+randn()*noise,j+randn()*noise) for (i,j) in zip(x,y)]
plot_ball_filter4 (f1, znoise, start_skip, end_skip)
run_4()
if __name__ == "__main__":
run_4()


@@ -6,13 +6,13 @@ Created on Sun May 11 20:47:52 2014
"""
from DogSensor import DogSensor
from KalmanFilter import KalmanFilter
from filterpy.kalman import KalmanFilter
import numpy as np
import matplotlib.pyplot as plt
import stats
def dog_tracking_filter(R,Q=0,cov=1.):
f = KalmanFilter (dim=2)
f = KalmanFilter (dim_x=2, dim_z=1)
f.x = np.matrix([[0], [0]]) # initial state (location and velocity)
f.F = np.matrix([[1,1],[0,1]]) # state transition matrix
f.H = np.matrix([[1,0]]) # Measurement function
@@ -31,7 +31,7 @@ def plot_track(noise, count, R, Q=0, plot_P=True, title='Kalman Filter'):
cov = []
for t in range (count):
z = dog.sense()
f.measure (z)
f.update (z)
#print (t,z)
ps.append (f.x[0,0])
cov.append(f.P)
@@ -53,4 +53,5 @@ def plot_track(noise, count, R, Q=0, plot_P=True, title='Kalman Filter'):
plt.show()
plot_track (noise=30, R=5, Q=2, count=20)
if __name__ == "__main__":
plot_track (noise=30, R=5, Q=2, count=20)


@@ -241,19 +241,6 @@ def _to_cov(x,n):
return np.eye(n) * x
def test_gaussian():
import scipy.stats
mean = 3.
var = 1.5
std = var*0.5
for i in np.arange(-5,5,0.1):
p0 = scipy.stats.norm(mean, std).pdf(i)
p1 = gaussian(i, mean, var)
assert abs(p0-p1) < 1.e15
def do_plot_test():
@@ -285,6 +272,7 @@ def do_plot_test():
print (count / len(x))
if __name__ == '__main__':
from scipy.stats import norm

test_stats.py (new file, 88 lines added)

@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""
py.test module to test stats.py module.
Created on Wed Aug 27 06:45:06 2014
@author: rlabbe
"""
from __future__ import division
from math import pi, exp
import numpy as np
from stats import gaussian, multivariate_gaussian, _to_cov
from numpy.linalg import inv
from numpy import linalg
def near_equal(x,y):
return abs(x-y) < 1.e-15
def test_gaussian():
import scipy.stats
mean = 3.
var = 1.5
std = var**0.5
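# scipy.stats.norm is parameterized by standard deviation, while gaussian()
# takes the variance, hence the conversion above.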
for i in np.arange(-5,5,0.1):
p0 = scipy.stats.norm(mean, std).pdf(i)
p1 = gaussian(i, mean, var)
assert near_equal(p0, p1)
def norm_pdf_multivariate(x, mu, sigma):
""" extremely literal transcription of the multivariate equation.
Slow, but easy to verify by eye compared to my version."""
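# The equation being transcribed, term by term:
#   p(x) = exp(-0.5 * (x-mu)' inv(sigma) (x-mu)) / sqrt((2*pi)^n * det(sigma))
# norm_const below is the denominator; result is the exponential term.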
n = len(x)
sigma = _to_cov(sigma,n)
det = linalg.det(sigma)
norm_const = 1.0 / (pow((2*pi), n/2) * pow(det, .5))
x_mu = x - mu
result = exp(-0.5 * (x_mu.dot(inv(sigma)).dot(x_mu.T)))
return norm_const * result
def test_multivariate():
from scipy.stats import multivariate_normal as mvn
from numpy.random import rand
mean = 3
var = 1.5
assert near_equal(mvn(mean,var).pdf(0.5),
multivariate_gaussian(0.5, mean, var))
mean = np.array([2.,17.])
var = np.array([[10., 1.2], [1.2, 4.]])
x = np.array([1,16])
assert near_equal(mvn(mean,var).pdf(x),
multivariate_gaussian(x, mean, var))
for i in range(100):
x = np.array([rand(), rand()])
assert near_equal(mvn(mean,var).pdf(x),
multivariate_gaussian(x, mean, var))
assert near_equal(mvn(mean,var).pdf(x),
norm_pdf_multivariate(x, mean, var))
mean = np.array([1,2,3,4])
var = np.eye(4)*rand()
x = np.array([2,3,4,5])
assert near_equal(mvn(mean,var).pdf(x),
norm_pdf_multivariate(x, mean, var))
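# These test_* functions are collected automatically by py.test; a typical run
# (assuming py.test is installed) is `py.test test_stats.py` from the directory
# that contains stats.py.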