Text and tests for multivariate Gaussian.
Added text in book about multivariate Gaussian. Wrote tests to confirm that my multivariate function yields the same results as numpy's version, and made everything compliant with py.test.
This commit is contained in:
parent
aeedd289f8
commit
951938802c
File diff suppressed because one or more lines are too long
85
bb_test.py
85
bb_test.py
@ -6,7 +6,7 @@ Created on Wed Jun 4 12:33:38 2014
|
||||
"""
|
||||
|
||||
from __future__ import print_function,division
|
||||
from KalmanFilter import KalmanFilter
|
||||
from filterpy.kalman import KalmanFilter
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import baseball
|
||||
@ -31,19 +31,19 @@ def ball_filter6(dt,R=1., Q = 0.1):
|
||||
[0,0,0,0,0,0]])
|
||||
|
||||
|
||||
f1.R = np.mat(np.eye(6)) * R
|
||||
|
||||
f1.R = np.mat(np.eye(6)) * R
|
||||
|
||||
f1.Q = np.zeros((6,6))
|
||||
f1.Q[2,2] = Q
|
||||
f1.Q[2,2] = Q
|
||||
f1.Q[5,5] = Q
|
||||
f1.x = np.mat([0, 0 , 0, 0, 0, 0]).T
|
||||
f1.P = np.eye(6) * 50.
|
||||
f1.B = 0.
|
||||
f1.u = 0
|
||||
|
||||
|
||||
return f1
|
||||
|
||||
|
||||
|
||||
|
||||
def ball_filter4(dt,R=1., Q = 0.1):
|
||||
f1 = KalmanFilter(dim=4)
|
||||
g = 10
|
||||
@ -57,44 +57,44 @@ def ball_filter4(dt,R=1., Q = 0.1):
|
||||
[0,0,0,0],
|
||||
[0,0,1,0],
|
||||
[0,0,0,0]])
|
||||
|
||||
|
||||
|
||||
|
||||
f1.B = np.mat([[0,0,0,0],
|
||||
[0,0,0,0],
|
||||
[0,0,1.,0],
|
||||
[0,0,0,1.]])
|
||||
|
||||
|
||||
f1.u = np.mat([[0],
|
||||
[0],
|
||||
[-0.5*g*dt**2],
|
||||
[-g*dt]])
|
||||
|
||||
f1.R = np.mat(np.eye(4)) * R
|
||||
|
||||
|
||||
f1.R = np.mat(np.eye(4)) * R
|
||||
|
||||
f1.Q = np.zeros((4,4))
|
||||
f1.Q[1,1] = Q
|
||||
f1.Q[1,1] = Q
|
||||
f1.Q[3,3] = Q
|
||||
f1.x = np.mat([0, 0 , 0, 0]).T
|
||||
f1.P = np.eye(4) * 50.
|
||||
f1.P = np.eye(4) * 50.
|
||||
return f1
|
||||
|
||||
|
||||
|
||||
def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
xs, ys = [],[]
|
||||
pxs, pys = [],[]
|
||||
|
||||
|
||||
for i,z in enumerate(zs):
|
||||
m = np.mat([z[0], 0, 0, z[1], 0, 0]).T
|
||||
|
||||
f1.predict ()
|
||||
|
||||
|
||||
if i == skip_start:
|
||||
x2 = xs[-2]
|
||||
x1 = xs[-1]
|
||||
y2 = ys[-2]
|
||||
y1 = ys[-1]
|
||||
|
||||
|
||||
if i >= skip_start and i <= skip_end:
|
||||
#x,y = baseball.predict (x2, y2, x1, y1, 1/30, 1/30* (i-skip_start+1))
|
||||
x,y = baseball.predict (xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
|
||||
@ -105,7 +105,7 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
|
||||
f1.update(m)
|
||||
|
||||
|
||||
|
||||
'''
|
||||
if i >= skip_start and i <= skip_end:
|
||||
#f1.x[2] = -0.1
|
||||
@ -114,13 +114,13 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
else:
|
||||
f1.update (m)
|
||||
|
||||
'''
|
||||
'''
|
||||
|
||||
xs.append (f1.x[0,0])
|
||||
ys.append (f1.x[3,0])
|
||||
pxs.append (z[0])
|
||||
pys.append(z[1])
|
||||
|
||||
|
||||
if i > 0 and z[1] < 0:
|
||||
break;
|
||||
|
||||
@ -130,34 +130,34 @@ def plot_ball_filter6 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
plt.legend([p1,p2], ['filter', 'measurement'], 2)
|
||||
plt.xlim([0,xs[-1]])
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
xs, ys = [],[]
|
||||
pxs, pys = [],[]
|
||||
|
||||
|
||||
for i,z in enumerate(zs):
|
||||
m = np.mat([z[0], 0, z[1], 0]).T
|
||||
|
||||
f1.predict ()
|
||||
|
||||
|
||||
if i == skip_start:
|
||||
x2 = xs[-2]
|
||||
x1 = xs[-1]
|
||||
y2 = ys[-2]
|
||||
y1 = ys[-1]
|
||||
|
||||
|
||||
if i >= skip_start and i <= skip_end:
|
||||
#x,y = baseball.predict (x2, y2, x1, y1, 1/30, 1/30* (i-skip_start+1))
|
||||
x,y = baseball.predict (xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
|
||||
|
||||
m[0] = x
|
||||
m[2] = y
|
||||
|
||||
|
||||
f1.update (m)
|
||||
|
||||
|
||||
|
||||
'''
|
||||
if i >= skip_start and i <= skip_end:
|
||||
#f1.x[2] = -0.1
|
||||
@ -166,13 +166,13 @@ def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
else:
|
||||
f1.update (m)
|
||||
|
||||
'''
|
||||
'''
|
||||
|
||||
xs.append (f1.x[0,0])
|
||||
ys.append (f1.x[2,0])
|
||||
pxs.append (z[0])
|
||||
pys.append(z[1])
|
||||
|
||||
|
||||
if i > 0 and z[1] < 0:
|
||||
break;
|
||||
|
||||
@ -187,32 +187,33 @@ def plot_ball_filter4 (f1, zs, skip_start=-1, skip_end=-1):
|
||||
start_skip = 20
|
||||
end_skip = 60
|
||||
|
||||
def run_6():
|
||||
def run_6():
|
||||
dt = 1/30
|
||||
noise = 0.0
|
||||
f1 = ball_filter6(dt, R=.16, Q=0.1)
|
||||
plt.cla()
|
||||
|
||||
|
||||
x,y = baseball.compute_trajectory(v_0_mph = 100., theta=50., dt=dt)
|
||||
|
||||
|
||||
|
||||
|
||||
znoise = [(i+randn()*noise,j+randn()*noise) for (i,j) in zip(x,y)]
|
||||
|
||||
|
||||
plot_ball_filter6 (f1, znoise, start_skip, end_skip)
|
||||
|
||||
|
||||
def run_4():
|
||||
def run_4():
|
||||
dt = 1/30
|
||||
noise = 0.0
|
||||
f1 = ball_filter4(dt, R=.16, Q=0.1)
|
||||
plt.cla()
|
||||
|
||||
|
||||
x,y = baseball.compute_trajectory(v_0_mph = 100., theta=50., dt=dt)
|
||||
|
||||
|
||||
|
||||
|
||||
znoise = [(i+randn()*noise,j+randn()*noise) for (i,j) in zip(x,y)]
|
||||
|
||||
|
||||
plot_ball_filter4 (f1, znoise, start_skip, end_skip)
|
||||
|
||||
|
||||
run_4()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_4()
|
@ -6,13 +6,13 @@ Created on Sun May 11 20:47:52 2014
|
||||
"""
|
||||
|
||||
from DogSensor import DogSensor
|
||||
from KalmanFilter import KalmanFilter
|
||||
from filterpy.kalman import KalmanFilter
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import stats
|
||||
|
||||
def dog_tracking_filter(R,Q=0,cov=1.):
|
||||
f = KalmanFilter (dim=2)
|
||||
f = KalmanFilter (dim_x=2, dim_z=1)
|
||||
f.x = np.matrix([[0], [0]]) # initial state (location and velocity)
|
||||
f.F = np.matrix([[1,1],[0,1]]) # state transition matrix
|
||||
f.H = np.matrix([[1,0]]) # Measurement function
|
||||
@ -31,7 +31,7 @@ def plot_track(noise, count, R, Q=0, plot_P=True, title='Kalman Filter'):
|
||||
cov = []
|
||||
for t in range (count):
|
||||
z = dog.sense()
|
||||
f.measure (z)
|
||||
f.update (z)
|
||||
#print (t,z)
|
||||
ps.append (f.x[0,0])
|
||||
cov.append(f.P)
|
||||
@ -53,4 +53,5 @@ def plot_track(noise, count, R, Q=0, plot_P=True, title='Kalman Filter'):
|
||||
plt.show()
|
||||
|
||||
|
||||
plot_track (noise=30, R=5, Q=2, count=20)
|
||||
if __name__ == "__main__":
|
||||
plot_track (noise=30, R=5, Q=2, count=20)
|
14
stats.py
14
stats.py
@ -241,19 +241,6 @@ def _to_cov(x,n):
|
||||
return np.eye(n) * x
|
||||
|
||||
|
||||
def test_gaussian():
|
||||
import scipy.stats
|
||||
|
||||
mean = 3.
|
||||
var = 1.5
|
||||
std = var*0.5
|
||||
|
||||
for i in np.arange(-5,5,0.1):
|
||||
p0 = scipy.stats.norm(mean, std).pdf(i)
|
||||
p1 = gaussian(i, mean, var)
|
||||
|
||||
assert abs(p0-p1) < 1.e15
|
||||
|
||||
|
||||
def do_plot_test():
|
||||
|
||||
@ -285,6 +272,7 @@ def do_plot_test():
|
||||
print (count / len(x))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
from scipy.stats import norm
|
||||
|
88
test_stats.py
Normal file
88
test_stats.py
Normal file
@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
py.test module to test stats.py module.
|
||||
|
||||
|
||||
Created on Wed Aug 27 06:45:06 2014
|
||||
|
||||
@author: rlabbe
|
||||
"""
|
||||
from __future__ import division
|
||||
from math import pi, exp
|
||||
import numpy as np
|
||||
from stats import gaussian, multivariate_gaussian, _to_cov
|
||||
from numpy.linalg import inv
|
||||
from numpy import linalg
|
||||
|
||||
|
||||
def near_equal(x,y):
|
||||
return abs(x-y) < 1.e-15
|
||||
|
||||
|
||||
def test_gaussian():
|
||||
import scipy.stats
|
||||
|
||||
mean = 3.
|
||||
var = 1.5
|
||||
std = var**0.5
|
||||
|
||||
for i in np.arange(-5,5,0.1):
|
||||
p0 = scipy.stats.norm(mean, std).pdf(i)
|
||||
p1 = gaussian(i, mean, var)
|
||||
|
||||
assert near_equal(p0, p1)
|
||||
|
||||
|
||||
|
||||
def norm_pdf_multivariate(x, mu, sigma):
|
||||
""" extremely literal transcription of the multivariate equation.
|
||||
Slow, but easy to verify by eye compared to my version."""
|
||||
|
||||
n = len(x)
|
||||
sigma = _to_cov(sigma,n)
|
||||
|
||||
det = linalg.det(sigma)
|
||||
|
||||
norm_const = 1.0 / (pow((2*pi), n/2) * pow(det, .5))
|
||||
x_mu = x - mu
|
||||
result = exp(-0.5 * (x_mu.dot(inv(sigma)).dot(x_mu.T)))
|
||||
return norm_const * result
|
||||
|
||||
|
||||
|
||||
def test_multivariate():
|
||||
from scipy.stats import multivariate_normal as mvn
|
||||
from numpy.random import rand
|
||||
|
||||
mean = 3
|
||||
var = 1.5
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(0.5),
|
||||
multivariate_gaussian(0.5, mean, var))
|
||||
|
||||
mean = np.array([2.,17.])
|
||||
var = np.array([[10., 1.2], [1.2, 4.]])
|
||||
|
||||
x = np.array([1,16])
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
multivariate_gaussian(x, mean, var))
|
||||
|
||||
for i in range(100):
|
||||
x = np.array([rand(), rand()])
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
multivariate_gaussian(x, mean, var))
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
norm_pdf_multivariate(x, mean, var))
|
||||
|
||||
|
||||
mean = np.array([1,2,3,4])
|
||||
var = np.eye(4)*rand()
|
||||
|
||||
x = np.array([2,3,4,5])
|
||||
|
||||
assert near_equal(mvn(mean,var).pdf(x),
|
||||
norm_pdf_multivariate(x, mean, var))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user