218 lines
5.0 KiB
Python
218 lines
5.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
Kalman filter experiments for tracking a simulated baseball trajectory
using 4-state and 6-state motion models.

Created on Wed Jun 4 12:33:38 2014

@author: rlabbe
"""
|
|
|
|
from __future__ import print_function,division
|
|
from KalmanFilter import KalmanFilter
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import baseball
|
|
from numpy.random import randn
|
|
|
|
def ball_filter6(dt, R=1., Q=0.1):
    """Create a 6-state Kalman filter for a ball in flight.

    State vector: ``[x, vx, ax, y, vy, s]``.  The x axis uses a
    constant-acceleration model.  In y, the last state ``s`` multiplies the
    gravity terms of ``F`` (it is held constant by the model), so it appears
    to act as a dimensionless gravity scale estimated by the filter.

    Parameters
    ----------
    dt : float
        Time step between filter epochs.
    R : float
        Measurement noise variance; placed on every diagonal entry of R.
    Q : float
        Process noise variance on the x-acceleration and gravity-scale states.

    Returns
    -------
    KalmanFilter
        Filter with zero initial state and initial covariance ``50 * I``.
    """
    f1 = KalmanFilter(dim=6)
    g = 10  # approximate gravitational acceleration used by the model

    # Position rows use the kinematic x + v*dt + 0.5*a*dt**2 form.
    f1.F = np.asmatrix([[1., dt, 0.5*dt**2, 0, 0, 0],
                        [0, 1., dt, 0, 0, 0],
                        [0, 0, 1., 0, 0, 0],
                        [0, 0, 0., 1., dt, -0.5*dt*dt*g],
                        [0, 0, 0, 0, 1., -g*dt],
                        [0, 0, 0, 0, 0., 1.]])

    # Only x (state 0) and y (state 3) are observed.  The zero rows pad H to
    # 6x6 so callers can pass 6-element measurement vectors.
    f1.H = np.asmatrix([[1, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 1, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0]])

    f1.R = np.asmatrix(np.eye(6)) * R

    f1.Q = np.zeros((6, 6))
    f1.Q[2, 2] = Q
    f1.Q[5, 5] = Q

    # Float state so later in-place writes are not truncated to integers.
    f1.x = np.asmatrix(np.zeros((6, 1)))
    f1.P = np.eye(6) * 50.

    # No control input for this model.
    f1.B = 0.
    f1.u = 0

    return f1
|
|
|
|
|
|
def ball_filter4(dt, R=1., Q=0.1):
    """Create a 4-state Kalman filter for a ball in flight.

    State vector: ``[x, vx, y, vy]`` with a constant-velocity transition.
    Gravity enters through the control term ``B * u``, which subtracts
    ``0.5*g*dt**2`` from y and ``g*dt`` from vy each step.

    Parameters
    ----------
    dt : float
        Time step between filter epochs.
    R : float
        Measurement noise variance; placed on every diagonal entry of R.
    Q : float
        Process noise variance on the two velocity states.

    Returns
    -------
    KalmanFilter
        Filter with zero initial state and initial covariance ``50 * I``.
    """
    f1 = KalmanFilter(dim=4)
    g = 10  # approximate gravitational acceleration used by the model

    f1.F = np.asmatrix([[1., dt, 0, 0],
                        [0, 1., 0, 0],
                        [0, 0, 1., dt],
                        [0, 0, 0., 1.]])

    # Only x (state 0) and y (state 2) are observed; the zero rows pad H to
    # 4x4 so callers can pass 4-element measurement vectors.
    f1.H = np.asmatrix([[1, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0]])

    # The control matrix selects the y and vy rows of u.
    f1.B = np.asmatrix([[0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 1., 0],
                        [0, 0, 0, 1.]])

    f1.u = np.asmatrix([[0],
                        [0],
                        [-0.5*g*dt**2],
                        [-g*dt]])

    f1.R = np.asmatrix(np.eye(4)) * R

    f1.Q = np.zeros((4, 4))
    f1.Q[1, 1] = Q
    f1.Q[3, 3] = Q

    # Float state so later in-place writes are not truncated to integers.
    f1.x = np.asmatrix(np.zeros((4, 1)))
    f1.P = np.eye(4) * 50.

    return f1
|
|
|
|
|
|
def plot_ball_filter6(f1, zs, skip_start=-1, skip_end=-1, plot=False):
    """Run a 6-state ball filter over measurements, optionally bridging a gap.

    Each epoch predicts, then updates with the measurement ``z = (x, y)``.
    For epochs with ``skip_start <= i <= skip_end``, the measurement is
    replaced by ``baseball.predict`` extrapolated from the last two filtered
    positions, which simulates the ball being unobserved.  The loop stops
    after the first measurement (other than the first one) with y < 0.

    Parameters
    ----------
    f1 : KalmanFilter
        Filter built by ``ball_filter6``.  Its state is read at index 0 (x)
        and index 3 (y).
    zs : iterable of (x, y)
        Position measurements.
    skip_start, skip_end : int
        Inclusive range of epochs whose measurements are replaced.  The
        default (-1, -1) disables skipping.
    plot : bool
        If True, plot the filtered track against the measurements.  The
        default False keeps the previous no-plot behaviour.

    Returns
    -------
    tuple of list
        ``(xs, ys, pxs, pys)``: filtered x/y positions and measured x/y.
    """
    xs, ys = [], []
    pxs, pys = [], []

    for i, z in enumerate(zs):
        # Float dtype so a substituted prediction is never truncated.
        m = np.asmatrix([z[0], 0., 0., z[1], 0., 0.], dtype=float).T

        f1.predict()

        # Extrapolation needs two previous filtered positions; until they
        # exist, keep the real measurement instead of raising IndexError.
        if skip_start <= i <= skip_end and len(xs) >= 2:
            # 1/30 s sample interval, assumed to match the trajectory's dt.
            x, y = baseball.predict(xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
            m[0] = x
            m[3] = y

        f1.update_long_form(m)

        xs.append(f1.x[0, 0])
        ys.append(f1.x[3, 0])
        pxs.append(z[0])
        pys.append(z[1])

        if i > 0 and z[1] < 0:
            break

    if plot and xs:
        p1, = plt.plot(xs, ys, 'r--')
        p2, = plt.plot(pxs, pys)
        plt.axis('equal')
        plt.legend([p1, p2], ['filter', 'measurement'], loc=2)
        plt.xlim([0, xs[-1]])
        plt.show()

    return xs, ys, pxs, pys
|
|
|
|
|
|
def plot_ball_filter4(f1, zs, skip_start=-1, skip_end=-1, plot=False):
    """Run a 4-state ball filter over measurements, optionally bridging a gap.

    Each epoch predicts, then updates with the measurement ``z = (x, y)``.
    For epochs with ``skip_start <= i <= skip_end``, the measurement is
    replaced by ``baseball.predict`` extrapolated from the last two filtered
    positions, which simulates the ball being unobserved.  The loop stops
    after the first measurement (other than the first one) with y < 0.

    Parameters
    ----------
    f1 : KalmanFilter
        Filter built by ``ball_filter4``.  Its state is read at index 0 (x)
        and index 2 (y).
    zs : iterable of (x, y)
        Position measurements.
    skip_start, skip_end : int
        Inclusive range of epochs whose measurements are replaced.  The
        default (-1, -1) disables skipping.
    plot : bool
        If True, plot the filtered track against the measurements.  The
        default False keeps the previous no-plot behaviour.

    Returns
    -------
    tuple of list
        ``(xs, ys, pxs, pys)``: filtered x/y positions and measured x/y.
    """
    xs, ys = [], []
    pxs, pys = [], []

    for i, z in enumerate(zs):
        # Float dtype so a substituted prediction is never truncated.
        m = np.asmatrix([z[0], 0., z[1], 0.], dtype=float).T

        f1.predict()

        # Extrapolation needs two previous filtered positions; until they
        # exist, keep the real measurement instead of raising IndexError.
        if skip_start <= i <= skip_end and len(xs) >= 2:
            # 1/30 s sample interval, assumed to match the trajectory's dt.
            x, y = baseball.predict(xs[-2], ys[-2], xs[-1], ys[-1], 1/30, 1/30)
            m[0] = x
            m[2] = y

        f1.update_long_form(m)

        xs.append(f1.x[0, 0])
        ys.append(f1.x[2, 0])
        pxs.append(z[0])
        pys.append(z[1])

        if i > 0 and z[1] < 0:
            break

    if plot and xs:
        p1, = plt.plot(xs, ys, 'r--')
        p2, = plt.plot(pxs, pys)
        plt.axis('equal')
        plt.legend([p1, p2], ['filter', 'measurement'], loc=2)
        plt.xlim([0, xs[-1]])
        plt.show()

    return xs, ys, pxs, pys
|
|
|
|
# Inclusive epoch range over which measurements are replaced by a ballistic
# prediction in the run_* drivers.
start_skip, end_skip = 20, 60
|
|
|
|
def run_6():
    """Run the 6-state filter over a simulated baseball trajectory.

    The trajectory is sampled at 1/30 s and perturbed with Gaussian noise
    scaled by ``meas_noise`` (0.0 here, so the samples are unchanged).
    """
    sample_dt = 1/30
    meas_noise = 0.0
    kf = ball_filter6(sample_dt, R=.16, Q=0.1)

    traj_x, traj_y = baseball.compute_trajectory(v_0_mph=100., theta=50., dt=sample_dt)

    # One randn() draw for x, then one for y, at every sample.
    measurements = []
    for tx, ty in zip(traj_x, traj_y):
        noisy_x = tx + randn() * meas_noise
        noisy_y = ty + randn() * meas_noise
        measurements.append((noisy_x, noisy_y))

    plot_ball_filter6(kf, measurements, start_skip, end_skip)
|
|
|
|
|
|
def run_4():
    """Run the 4-state filter over a simulated baseball trajectory.

    The current axes are cleared first.  The trajectory is sampled at
    1/30 s and perturbed with Gaussian noise scaled by ``meas_noise``
    (0.0 here, so the samples are unchanged).
    """
    sample_dt = 1/30
    meas_noise = 0.0
    kf = ball_filter4(sample_dt, R=.16, Q=0.1)
    plt.cla()

    traj_x, traj_y = baseball.compute_trajectory(v_0_mph=100., theta=50., dt=sample_dt)

    # One randn() draw for x, then one for y, at every sample.
    measurements = []
    for tx, ty in zip(traj_x, traj_y):
        noisy_x = tx + randn() * meas_noise
        noisy_y = ty + randn() * meas_noise
        measurements.append((noisy_x, noisy_y))

    plot_ball_filter4(kf, measurements, start_skip, end_skip)
|
|
|
|
|
|
# Run the 6-state experiment only when executed as a script, not on import.
if __name__ == "__main__":
    run_6()