2015-02-01 07:42:19 +01:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
2015-08-01 17:46:14 +02:00
|
|
|
"""Copyright 2015 Roger R Labbe Jr.
|
|
|
|
|
|
|
|
|
|
|
|
Code supporting the book
|
|
|
|
|
|
|
|
Kalman and Bayesian Filters in Python
|
|
|
|
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
|
|
|
|
|
|
|
|
|
2015-08-01 17:52:48 +02:00
|
|
|
This is licensed under an MIT license. See the LICENSE.txt file
|
2015-08-01 17:46:14 +02:00
|
|
|
for more information.
|
2015-02-01 07:42:19 +01:00
|
|
|
"""
|
2015-08-01 17:46:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
from __future__ import (absolute_import, division, print_function,
|
|
|
|
unicode_literals)
|
|
|
|
|
2015-02-01 07:42:19 +01:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
|
2015-04-18 23:52:41 +02:00
|
|
|
def plot_errorbars(bars, xlims):
    """Draw one horizontal error bar per entry of `bars` and show the figure.

    Each entry of `bars` is indexed positionally: element 0 is the x position
    of the marker, element 1 is the symmetric x error, and element 2 is the
    legend label. The first bar is drawn at y = 1.0 and each following bar
    0.2 higher. `xlims[0]` and `xlims[1]` become the x-axis limits; the y axis
    is fixed to (0, 2) and its ticks are removed.
    """
    height = 1.0
    for entry in bars:
        plt.errorbar([entry[0]], [height], xerr=[entry[1]], fmt='o',
                     label=entry[2], capthick=2, capsize=10)
        # Stack the next bar a little above this one.
        height += 0.2

    plt.ylim(0, 2)
    left, right = xlims[0], xlims[1]
    plt.xlim(left, right)
    show_legend()

    # The y position only separates the bars; it carries no meaning, so hide
    # its ticks.
    current_axes = plt.gca().axes
    current_axes.yaxis.set_ticks([])
    plt.show()
|
2015-06-23 02:56:14 +02:00
|
|
|
|
|
|
|
|
2015-04-18 23:52:41 +02:00
|
|
|
|
|
|
|
|
|
|
|
def show_legend():
    """Add a legend anchored at (1, 0.5) in axes coordinates.

    The legend's 'center left' point is pinned to that anchor, i.e. the
    middle of the right-hand edge of the axes.
    """
    anchor = (1, 0.5)
    plt.legend(bbox_to_anchor=anchor, loc='center left')
|
|
|
|
|
|
|
|
|
2015-02-01 07:42:19 +01:00
|
|
|
def bar_plot(pos, ylim=(0,1), title=None):
    """Draw `pos` as a bar chart on the current axes, one bar per index.

    The current axes are cleared first. Bar i has height pos[i] and is
    centred on x = i, with a tick labelled i directly underneath it.

    Parameters
    ----------
    pos : sequence of numbers
        Bar heights; only len(pos) and the values themselves are used.
    ylim : pair of numbers or falsy, optional
        Limits for the y axis. Any falsy value (e.g. None) leaves the
        y limits to matplotlib.
    title : str, optional
        Plot title; omitted when None.
    """
    plt.cla()
    ax = plt.gca()
    x = np.arange(len(pos))

    # Pin the alignment explicitly instead of relying on the library default,
    # which changed from edge-aligned (matplotlib 1.x) to centre-aligned
    # (2.0+). The previous fixed tick offset of x + 0.4 only matched
    # edge-aligned bars and put every tick at the bar's right edge otherwise.
    ax.bar(x, pos, color='#30a2da', align='center')
    if ylim:
        plt.ylim(ylim)
    plt.xticks(x, x)
    if title is not None:
        plt.title(title)
|
|
|
|
|
|
|
|
|
2015-07-09 23:28:50 +02:00
|
|
|
def set_labels(title=None, x=None, y=None):
    """ Helps keep the code in the book short. Optionally sets the x label,
    the y label and the title of the current plot; any argument left as None
    is skipped.
    """
    # (pyplot function name, text) pairs, applied in this order.
    requests = (('xlabel', x), ('ylabel', y), ('title', title))
    for setter_name, text in requests:
        if text is not None:
            getattr(plt, setter_name)(text)
|
|
|
|
|
|
|
|
|
|
|
|
def set_limits(x, y):
    """ Helper to keep the code in the book short. Passes `x` to set_xlim and
    `y` to set_ylim on the current axes.
    """
    axes = plt.gca()
    axes.set_xlim(x)
    axes.set_ylim(y)
|
|
|
|
|
2015-07-26 08:46:59 +02:00
|
|
|
def plot_predictions(p, rng=None):
    """Scatter the predictions `p` as unfilled, red-edged 'v' markers.

    `rng` supplies the x positions; when it is None the indices
    0 .. len(p)-1 are used. The points carry the legend label 'prediction'.
    """
    positions = range(len(p)) if rng is None else rng
    marker_style = dict(marker='v', s=40, edgecolor='r',
                        facecolor='None', lw=2, label='prediction')
    plt.scatter(positions, p, **marker_style)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_kf_output(xs, filter_xs, zs, title=None, aspect_equal=True):
    """Plot filter output, track and (optionally) measurements, then show.

    Column 0 of `filter_xs` is drawn with plot_filter and column 0 of `xs`
    with plot_track, so both must support 2-D `[:, 0]` indexing. `zs` is
    drawn with plot_measurements unless it is None. The axes are labelled
    'time (sec)' / 'meters', titled with `title`, and the x axis runs from
    -1 to len(xs). With `aspect_equal` true the axes use an equal aspect
    ratio.
    """
    # Draw order is kept fixed: it determines stacking and legend order.
    estimate = filter_xs[:, 0]
    plot_filter(estimate)
    truth = xs[:, 0]
    plot_track(truth)

    have_measurements = zs is not None
    if have_measurements:
        plot_measurements(zs)

    show_legend()
    set_labels(title=title, y='meters', x='time (sec)')

    if aspect_equal:
        plt.gca().set_aspect('equal')

    x_range = (-1, len(xs))
    plt.xlim(x_range)
    plt.show()
|
|
|
|
|
2015-07-09 23:28:50 +02:00
|
|
|
|
2015-07-22 20:13:07 +02:00
|
|
|
def plot_measurements(xs, ys=None, color='k', lw=2, label='Measurements',
                      lines=False, **kwargs):
    """ Helper function to give a consistent way to display
    measurements in the book.

    Parameters
    ----------
    xs : sequence
        x values when `ys` is given; otherwise the measurements themselves,
        which are then plotted against their indices.
    ys : sequence, optional
        Measurement values paired with `xs`.
    color : matplotlib color, optional
        Line color (lines mode) or marker edge color (scatter mode).
    lw : number, optional
        Line width (lines mode) or marker edge width (scatter mode).
    label : str, optional
        Legend label.
    lines : bool, optional
        If true, draw a dashed line; otherwise draw unfilled markers.
    **kwargs
        Forwarded unchanged to plt.plot or plt.scatter.
    """
    plt.autoscale(tight=True)

    if lines:
        coords = (xs,) if ys is None else (xs, ys)
        plt.plot(*coords, color=color, lw=lw, ls='--', label=label, **kwargs)
        return

    if ys is None:
        # scatter needs explicit x positions; use the sample indices.
        xs, ys = range(len(xs)), xs

    # Honour the caller's `lw` here as well; it used to be hard-coded to 2 in
    # scatter mode, which silently ignored the parameter. The default of 2
    # keeps existing plots unchanged.
    plt.scatter(xs, ys, edgecolor=color, facecolor='none',
                lw=lw, label=label, **kwargs)
|
2015-02-01 07:42:19 +01:00
|
|
|
|
|
|
|
|
2015-07-13 23:42:34 +02:00
|
|
|
def plot_residual_limits(Ps, stds=1.):
    """ Plots the band of +/- `stds` standard deviations, where the standard
    deviation is the square root of `Ps`, as a yellow shaded region outlined
    by dotted black lines. One standard deviation by default; use `stds` for
    a different choice (e.g. stds=3 for 3 standard deviations).
    """
    upper = np.sqrt(Ps) * stds
    lower = -upper

    for bound in (lower, upper):
        plt.plot(bound, color='k', ls=':', lw=2)

    sample_index = range(len(upper))
    plt.fill_between(sample_index, lower, upper,
                     facecolor='#ffff00', alpha=0.3)
|
|
|
|
|
|
|
|
|
2015-04-21 02:14:51 +02:00
|
|
|
def plot_track(xs, ys=None, label='Track', c='k', lw=2, **kwargs):
    """Plot a track as a dotted line.

    With `ys` given the line is drawn through (xs, ys); otherwise `xs` alone
    is handed to plt.plot. `c` is the line color, `lw` its width, `label` the
    legend entry, and any extra keyword arguments go straight to plt.plot.
    """
    coords = (xs,) if ys is None else (xs, ys)
    plt.plot(*coords, color=c, lw=lw, ls=':', label=label, **kwargs)
|
2015-02-01 07:42:19 +01:00
|
|
|
|
|
|
|
|
2015-08-01 17:46:14 +02:00
|
|
|
def plot_filter(xs, ys=None, c='#013afe', label='Filter', var=None, **kwargs):
    """Plot filter output, optionally with a one-standard-deviation band.

    When `ys` is None, `xs` is treated as the values and plotted against
    its indices. If `var` is given, its square root is added to and
    subtracted from the values; the two resulting curves are drawn as dotted
    black lines and the area between them is shaded yellow. Extra keyword
    arguments are forwarded to the main plt.plot call only.
    """
    if ys is None:
        xs, ys = range(len(xs)), xs

    plt.plot(xs, ys, color=c, label=label, **kwargs)

    if var is not None:
        spread = np.sqrt(np.asarray(var))
        upper = ys + spread
        lower = ys - spread

        for edge in (upper, lower):
            plt.plot(xs, edge, linestyle=':', color='k', lw=2)
        plt.fill_between(xs, lower, upper,
                         facecolor='yellow', alpha=0.2)
|
2015-02-01 07:42:19 +01:00
|
|
|
|
|
|
|
|
2015-06-27 17:37:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _blob(x, y, area, colour):
    """
    Fills a square of the given area, centred on (x, y), using `colour`
    for both the face and the edge.
    """
    half_side = np.sqrt(area) / 2
    left, right = x - half_side, x + half_side
    bottom, top = y - half_side, y + half_side

    # Corners in order: bottom-left, bottom-right, top-right, top-left.
    corner_xs = np.array([left, right, right, left])
    corner_ys = np.array([bottom, bottom, top, top])
    plt.fill(corner_xs, corner_ys, colour, edgecolor=colour)
|
|
|
|
|
|
|
|
def hinton(W, maxweight=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.

    The current figure is cleared and a gray background the size of the
    matrix is drawn. Each positive entry becomes a white square and each
    negative entry a black square whose area is |w| / maxweight, capped at 1.
    Zero entries are left blank. Row 0 of `W` is drawn at the top.

    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever. The previous mode is restored before
    returning, even if drawing raises.

    Parameters
    ----------
    W : 2-D numpy array
        Weight matrix; must provide `.shape` and `W[row, col]` indexing.
    maxweight : number, optional
        Weight that maps to a full-size square. When falsy, the largest
        absolute weight rounded up to the next power of two is used.
    """
    # Remember whether interactive mode was on so it can be switched back.
    # (Previously the flag was never set, leaving interactive mode off for
    # good after the first call.)
    reenable = plt.isinteractive()
    if reenable:
        plt.ioff()

    try:
        plt.clf()
        height, width = W.shape
        if not maxweight:
            maxweight = 2**np.ceil(np.log(np.max(np.abs(W)))/np.log(2))

        plt.fill(np.array([0, width, width, 0]),
                 np.array([0, 0, height, height]),
                 'gray')

        plt.axis('off')
        plt.axis('equal')
        for x in range(width):
            for y in range(height):
                w = W[y, x]
                if w > 0:
                    colour = 'white'
                elif w < 0:
                    colour = 'black'
                else:
                    # Nothing to draw for zero (or non-comparable) weights.
                    continue
                # Cell centre; y is flipped so that row 0 ends up on top.
                _blob(x + 0.5,
                      height - y - 0.5,
                      min(1, abs(w)/maxweight),
                      colour)
    finally:
        if reenable:
            plt.ion()
|
|
|
|
|
|
|
|
|
2015-02-01 07:42:19 +01:00
|
|
|
if __name__ == "__main__":
    # Small manual demo: draw a bar chart of 20 sample probabilities (the same
    # five values repeated four times) and overlay them as measurements.
    _pattern = [0.2245871, 0.06288015, 0.06109133, 0.0581008, 0.09334062]
    p = _pattern * 4
    bar_plot(p)
    plot_measurements(p)
|