2014-05-03 04:49:35 +02:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
|
|
Created on Thu May 1 16:56:49 2014
|
|
|
|
|
|
|
|
@author: rlabbe
|
|
|
|
"""
|
2014-05-12 05:44:25 +02:00
|
|
|
import numpy as np
|
|
|
|
from matplotlib.patches import Ellipse
|
2014-05-05 00:33:39 +02:00
|
|
|
import matplotlib.pyplot as plt
|
2014-08-27 14:32:08 +02:00
|
|
|
from matplotlib import cm
|
|
|
|
from mpl_toolkits.mplot3d import Axes3D
|
2014-12-21 05:00:17 +01:00
|
|
|
from numpy.random import multivariate_normal
|
2014-05-12 05:44:25 +02:00
|
|
|
import stats
|
2014-05-03 04:49:35 +02:00
|
|
|
|
|
|
|
def show_residual_chart():
    """ Draws the 'Kalman Filter Predict and Update' diagram.

    The chart contains four points: the prior estimate at (1, 2), the
    prediction at (2, 3), the measurement at (2, 2.3) and the new estimate
    at (2, 2.8). An arrow runs from the prior estimate to the prediction,
    and the residual is drawn as the line segment joining the measurement
    and the prediction. The figure is displayed with plt.show().
    """

    plt.xlim([0.9, 2.5])
    plt.ylim([1.5, 3.5])

    # prior estimate, prediction and measurement share one scatter call;
    # the new estimate is drawn separately so it gets its own color
    plt.scatter([1, 2, 2], [2, 3, 2.3])
    plt.scatter([2], [2.8], marker='o')

    # Annotate the axes that already hold the scatter plots. plt.axes()
    # adds a new axes to the figure, so the annotations could end up on an
    # empty axes stacked over the data; plt.gca() returns the current one.
    ax = plt.gca()

    # prior estimate -> prediction
    ax.annotate('', xy=(2, 3), xytext=(1, 2),
                arrowprops=dict(arrowstyle='->', ec='#004080',
                                lw=2,
                                shrinkA=3, shrinkB=4))

    ax.annotate('prediction', xy=(2.04, 3.), color='#004080')
    ax.annotate('measurement', xy=(2.05, 2.28))
    ax.annotate('prior estimate', xy=(1, 1.9))
    ax.annotate('residual', xy=(2.04, 2.6), color='#e24a33')
    ax.annotate('new estimate', xy=(2, 2.8), xytext=(2.1, 2.8),
                arrowprops=dict(arrowstyle='->', ec="k",
                                shrinkA=3, shrinkB=4))

    # residual: plain line (no arrow head) from measurement to prediction
    ax.annotate('', xy=(2, 3), xytext=(2, 2.3),
                arrowprops=dict(arrowstyle="-",
                                ec="#e24a33",
                                lw=2,
                                shrinkA=5, shrinkB=5))

    plt.title("Kalman Filter Predict and Update")
    plt.axis('equal')
    plt.show()
|
|
|
|
|
2014-05-17 02:10:23 +02:00
|
|
|
|
2014-05-12 05:44:25 +02:00
|
|
|
def show_position_chart():
    """ Displays 3 measurements at t=1,2,3, with x=1,2,3.

    Each point is labelled with its time step just below the marker, and
    the figure is displayed with plt.show().
    """

    plt.scatter([1, 2, 3], [1, 2, 3], s=128, color='#004080')
    plt.xlim([0, 4])
    plt.ylim([0, 4])

    # label every point 10 points below its marker
    for label, position in (('t=1', (1, 1)),
                            ('t=2', (2, 2)),
                            ('t=3', (3, 3))):
        plt.annotate(label, xy=position, xytext=(0, -10),
                     textcoords='offset points', ha='center', va='top')

    plt.xlabel("X")
    plt.ylabel("Y")

    ticks = np.arange(1, 4, 1)
    plt.xticks(ticks)
    plt.yticks(np.arange(1, 4, 1))
    plt.show()
|
|
|
|
|
2014-08-27 14:32:08 +02:00
|
|
|
|
2014-05-12 05:44:25 +02:00
|
|
|
def show_position_prediction_chart():
    """ Displays 3 measurements, with the next position predicted.

    The three measurements are drawn at (1,1), (2,2) and (3,3) and labelled
    t=1..3. The predicted point is drawn at (4,4) in a different color,
    with an arrow leading to it from the last measurement. The figure is
    displayed with plt.show().
    """

    plt.scatter([1, 2, 3], [1, 2, 3], s=128, color='#004080')

    # label every measurement 10 points below its marker
    for label, position in (('t=1', (1, 1)),
                            ('t=2', (2, 2)),
                            ('t=3', (3, 3))):
        plt.annotate(label, xy=position, xytext=(0, -10),
                     textcoords='offset points', ha='center', va='top')

    plt.xlim([0, 5])
    plt.ylim([0, 5])

    plt.xlabel("Position")
    plt.ylabel("Time")

    plt.xticks(np.arange(1, 5, 1))
    plt.yticks(np.arange(1, 5, 1))

    # Predicted position. Only `color` is passed: matplotlib raises a
    # ValueError when a scatter call supplies both `c` and `color`.
    plt.scatter([4], [4], s=128, color='#8EBA42')

    # Use the axes that already hold the points; plt.axes() would add a
    # new axes to the figure instead of returning the current one.
    ax = plt.gca()
    ax.annotate('', xy=(4, 4), xytext=(3, 3),
                arrowprops=dict(arrowstyle='->',
                                ec='g',
                                shrinkA=6, shrinkB=5,
                                lw=3))
    plt.show()
|
|
|
|
|
|
|
|
|
2015-02-02 05:39:08 +01:00
|
|
|
def show_x_error_chart(count):
    """ Displays covariance ellipses in position/velocity space, adding
    more ellipses as count increases.

    The current axes are cleared first and the figure is displayed with
    plt.show().

    Parameters
    ----------
    count : int
        selects what is drawn:

        * count >= 1 : the ellipse for [[0.03, 0], [0, 8]] centered at (0,0)
        * count == 2 or 3 : additionally the same ellipse centered at (5,5)
        * count == 3 : additionally the ellipse for
          [[12, 11.95], [11.95, 12]] centered at (5,5), with a red edge
        * count == 4 : both ellipses at (5,5) drawn with alpha=0.25, plus
          the ellipse for the mean and covariance returned by
          stats.multivariate_multiply
    """

    plt.cla()
    plt.gca().autoscale(tight=True)

    cov = np.array([[0.03, 0], [0, 8]])
    e = stats.covariance_ellipse(cov)

    cov2 = np.array([[0.03, 0], [0, 4]])

    cov3 = np.array([[12, 11.95], [11.95, 12]])
    e3 = stats.covariance_ellipse(cov3)

    # passed to every plot call as `variance`, so each covariance is drawn
    # as a set of three nested ellipses
    sigma = [1, 4, 9]

    if count >= 1:
        stats.plot_covariance_ellipse((0, 0), ellipse=e, variance=sigma)

    if count in (2, 3):
        stats.plot_covariance_ellipse((5, 5), ellipse=e, variance=sigma)

    if count == 3:
        stats.plot_covariance_ellipse((5, 5), ellipse=e3, variance=sigma,
                                      edgecolor='r')

    if count == 4:
        M1 = np.array([[5, 5]]).T

        # NOTE(review): the faded ellipse drawn below comes from `cov`
        # (variance 8) while the multiplication uses `cov2` (variance 4)
        # - confirm this difference is intended.
        m4, cov4 = stats.multivariate_multiply(M1, cov2, M1, cov3)
        e4 = stats.covariance_ellipse(cov4)

        stats.plot_covariance_ellipse((5, 5), ellipse=e, variance=sigma,
                                      alpha=0.25)
        stats.plot_covariance_ellipse((5, 5), ellipse=e3, variance=sigma,
                                      edgecolor='r', alpha=0.25)

        # m4 is indexed as a 2-D array; its first column is the center
        stats.plot_covariance_ellipse(m4[:, 0], ellipse=e4, variance=sigma)

    plt.xlabel("Position")
    plt.ylabel("Velocity")

    plt.show()
|
|
|
|
|
2014-08-27 14:32:08 +02:00
|
|
|
|
2014-05-12 05:44:25 +02:00
|
|
|
def show_x_with_unobserved():
    """ Shows x=1,2,3 with velocity superimposed on top.

    Draws one covariance ellipse centered at (2,2), three identical tall
    ellipses centered at x=1,2,3, and a red ellipse outline at (2,2). The
    figure is displayed with plt.show().
    """

    # plot velocity
    velocity_sigma = [0.5, 1., 1.5, 2]
    velocity_cov = np.array([[1, 1], [1, 1.1]])
    stats.plot_covariance_ellipse((2, 2), cov=velocity_cov,
                                  variance=velocity_sigma, axis_equal=False)

    # plot positions: the same ellipse is reused at each x
    position_cov = np.array([[0.003, 0], [0, 12]])
    position_sigma = [0.5, 1., 1.5, 2]
    position_ellipse = stats.covariance_ellipse(position_cov)

    for center in ((1, 1), (2, 1), (3, 1)):
        stats.plot_covariance_ellipse(center, ellipse=position_ellipse,
                                      variance=position_sigma,
                                      axis_equal=False)

    # plot intersection circle
    intersection = Ellipse(xy=(2, 2), width=.2, height=1.2,
                           edgecolor='r', fc='None', lw=4)
    plt.gca().add_artist(intersection)

    plt.ylim([0, 11])
    plt.xlim([0, 4])
    plt.xticks(np.arange(1, 4, 1))

    plt.xlabel("Position")
    plt.ylabel("Time")

    plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
2014-08-27 14:32:08 +02:00
|
|
|
def plot_3d_covariance(mean, cov):
    """ Plots a 2x2 covariance matrix positioned at mean as a 3D surface.
    mean will be plotted in x and y, and the probability in the z axis.

    The z values are 100 times the value returned by
    stats.multivariate_gaussian on a 40x40 grid around mean. Contours of
    the surface are also projected along the x and y directions. A new
    figure is created; this function does not call plt.show().

    Parameters
    ----------
    mean : 2x1 tuple-like object
        mean for x and y coordinates. For example (2.3, 7.5)

    cov : 2x2 nd.array
        the covariance matrix
    """

    # compute width and height of covariance ellipse so we can choose
    # appropriate ranges for x and y
    o, w, h = stats.covariance_ellipse(cov, 3)

    # rotate width and height to x,y axis
    wx = abs(w*np.cos(o) + h*np.sin(o)) * 1.2
    wy = abs(h*np.cos(o) - w*np.sin(o)) * 1.2

    # ensure axis are of the same size so everything is plotted with the
    # same scale
    half_range = wx if wx > wy else wy

    minx = mean[0] - half_range
    maxx = mean[0] + half_range
    miny = mean[1] - half_range
    maxy = mean[1] + half_range

    # linspace rather than arange: with a floating point step arange may
    # return either 40 or 41 samples depending on rounding
    xs = np.linspace(minx, maxx, 40, endpoint=False)
    ys = np.linspace(miny, maxy, 40, endpoint=False)
    xv, yv = np.meshgrid(xs, ys)

    # evaluate the density point by point, then restore the grid shape
    zs = np.array([100. * stats.multivariate_gaussian(np.array([x, y]),
                                                      mean, cov)
                   for x, y in zip(np.ravel(xv), np.ravel(yv))])
    zv = zs.reshape(xv.shape)

    ax = plt.figure().add_subplot(111, projection='3d')
    ax.plot_surface(xv, yv, zv, rstride=1, cstride=1, cmap=cm.autumn)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    # project the surface onto planes offset along x and along y
    ax.contour(xv, yv, zv, zdir='x', offset=minx-1, cmap=cm.autumn)
    ax.contour(xv, yv, zv, zdir='y', offset=maxy, cmap=cm.BuGn)
|
|
|
|
|
|
|
|
|
2014-12-21 05:00:17 +01:00
|
|
|
def plot_3d_sampled_covariance(mean, cov):
    """ Plots 1000 random samples drawn from the Gaussian described by
    mean and cov, together with contours of its density.

    The samples are drawn with numpy.random.multivariate_normal and
    scattered on the z=0 plane. The density, 100 times the value returned
    by stats.multivariate_gaussian on a 40x40 grid around mean, is shown
    only as contours projected along the x and y directions. A new figure
    is created; this function does not call plt.show().

    Parameters
    ----------
    mean : 2x1 tuple-like object
        mean for x and y coordinates. For example (2.3, 7.5)

    cov : 2x2 nd.array
        the covariance matrix
    """

    # compute width and height of covariance ellipse so we can choose
    # appropriate ranges for x and y
    o, w, h = stats.covariance_ellipse(cov, 3)

    # rotate width and height to x,y axis
    wx = abs(w*np.cos(o) + h*np.sin(o)) * 1.2
    wy = abs(h*np.cos(o) - w*np.sin(o)) * 1.2

    # ensure axis are of the same size so everything is plotted with the
    # same scale
    half_range = wx if wx > wy else wy

    minx = mean[0] - half_range
    maxx = mean[0] + half_range
    miny = mean[1] - half_range
    maxy = mean[1] + half_range

    count = 1000
    x, y = multivariate_normal(mean=mean, cov=cov, size=count).T

    # linspace rather than arange: with a floating point step arange may
    # return either 40 or 41 samples depending on rounding
    xs = np.linspace(minx, maxx, 40, endpoint=False)
    ys = np.linspace(miny, maxy, 40, endpoint=False)
    xv, yv = np.meshgrid(xs, ys)

    # evaluate the density point by point, then restore the grid shape
    zs = np.array([100. * stats.multivariate_gaussian(np.array([xx, yy]),
                                                      mean, cov)
                   for xx, yy in zip(np.ravel(xv), np.ravel(yv))])
    zv = zs.reshape(xv.shape)

    ax = plt.figure().add_subplot(111, projection='3d')

    # samples lie flat on the z=0 plane
    ax.scatter(x, y, [0]*count, marker='.')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    # project the density onto planes offset along x and along y
    ax.contour(xv, yv, zv, zdir='x', offset=minx-1, cmap=cm.autumn)
    ax.contour(xv, yv, zv, zdir='y', offset=maxy, cmap=cm.BuGn)
|
|
|
|
|
|
|
|
|
2014-06-22 23:18:04 +02:00
|
|
|
if __name__ == "__main__":
    # Running the module as a script shows one demo chart. The other
    # charts can be tried by swapping in one of these calls:
    #   show_position_chart()
    #   plot_3d_covariance((2,7), np.array([[8.,0],[0,4.]]))
    #   plot_3d_sampled_covariance([2,7], [[8.,0],[0,4.]])
    #   show_residual_chart()
    show_x_error_chart(count=4)
|
2014-12-31 22:22:16 +01:00
|
|
|
|