Kalman-and-Bayesian-Filters.../kf_book/book_plots.py

498 lines
15 KiB
Python

# -*- coding: utf-8 -*-
"""Copyright 2015 Roger R Labbe Jr.
Code supporting the book
Kalman and Bayesian Filters in Python
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
This is licensed under an MIT license. See the LICENSE.txt file
for more information.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from contextlib import contextmanager
import sys
import time
import ipywidgets
import matplotlib as mpl
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
# Optional dependency. The imported name is never referenced in this module,
# so the import is presumably wanted only for its side effects on matplotlib
# styling -- TODO confirm. Any failure to load it is deliberately ignored so
# the book's plots still work without it; `Exception` (rather than a bare
# `except`) keeps KeyboardInterrupt/SystemExit from being swallowed.
try:
    import seaborn  # noqa: F401
except Exception:
    pass

# Default (width, height) of a figure, in inches.
_default_size = (9, 4)
def equal_axis(sz=_default_size[0]):
    """Make the default figure square and use equal axis scaling.

    The default figure size is set to `sz` by `sz` inches and
    ``plt.axis('equal')`` is then applied.
    """
    square = (sz, sz)
    pylab.rcParams['figure.figsize'] = square
    plt.axis('equal')
def reset_figsize():
    """Reset the figure size in inches to the default size for the book."""
    book_default = _default_size
    mpl.rcParams['figure.figsize'] = book_default
def set_figsize(x=_default_size[0], y=_default_size[1]):
    """Set the default figure size to `x` (width) by `y` (height) inches."""
    width_height = (x, y)
    mpl.rcParams['figure.figsize'] = width_height
@contextmanager
def figsize(x=8, y=3):
    """Temporarily set the figure size using 'with figsize(a, b):'.

    Parameters
    ----------
    x : number, default 8
        figure width in inches
    y : number, default 3
        figure height in inches

    The figure size that was configured on entry is put back when the
    ``with`` block exits, including when the block raises an exception.
    """
    size = pylab.rcParams['figure.figsize']
    set_figsize(x, y)
    try:
        yield
    finally:
        # Restore inside `finally` so an exception raised in the with-body
        # cannot leave the temporary size installed for later figures.
        pylab.rcParams['figure.figsize'] = size
""" If the plot is inline (%matplotlib inline) we need to
do special processing for the interactive_plot context manager,
otherwise it outputs a lot of extra <matplotlib.figure.figure
type output into the notebook."""
# True when the name of the active matplotlib backend contains 'backend_inline'.
IS_INLINE = 'backend_inline' in mpl.get_backend()
def plot_errorbars(bars, xlims, ylims=(-1, 1)):
    """Plot a list of horizontal error bars with the given x and y limits.

    Each entry of `bars` is a tuple (or any indexable sequence) containing

        (mean value, error plus/minus, label)

    For example (160, 3, 'A') draws an error bar from 157 to 163, with the
    legend label 'A'.

    Parameters
    ----------
    bars : list
        list of tuples in form (mean, error +/-, label)

    xlims : tuple
        tuple containing min and max values for the x axis

    ylims : tuple, optional
        tuple containing min and max values for the y axis

    Example
    -------
    >>> plot_errorbars([(160, 3, 'A'), (170, 9, 'B')], xlims=(150, 180))
    """
    with figsize(y=2):
        height = 0.0
        for entry in bars:
            center, spread, name = entry[0], entry[1], entry[2]
            plt.errorbar([center], [height], xerr=[spread], fmt='o',
                         label=name, capthick=2, capsize=10)
            # each bar is drawn 0.2 higher than the previous one
            height += 0.2

        plt.ylim(*ylims)
        plt.xlim(xlims[0], xlims[1])
        show_legend()
        # the vertical position carries no meaning, so hide the y ticks
        plt.gca().axes.yaxis.set_ticks([])
        plt.show()
def predict_update_chart(box_bg='#CCCCCC',
                         arrow1='#88CCFF',
                         arrow2='#88FF88'):
    """Draw and show the predict/update cycle diagram.

    Two circles labelled "Predict Step" and "Update Step" are connected by
    curved arrows, with annotations for the measurement, the initial
    conditions and the state estimate.

    Parameters
    ----------
    box_bg : color, default '#CCCCCC'
        face color of the two step circles
    arrow1 : color, default '#88CCFF'
        accepted for backward compatibility; not used by the drawing code
    arrow2 : color, default '#88FF88'
        accepted for backward compatibility; not used by the drawing code
    """
    plt.figure(figsize=(4, 4), facecolor='w')
    ax = plt.axes((0, 0, 1, 1),
                  xticks=[], yticks=[], frameon=False)

    pc = Circle((4, 5), 0.7, fc=box_bg)
    uc = Circle((6, 5), 0.7, fc=box_bg)
    ax.add_patch(pc)
    ax.add_patch(uc)

    plt.text(4, 5, "Predict\nStep", ha='center', va='center', fontsize=12)
    plt.text(6, 5, "Update\nStep", ha='center', va='center', fontsize=12)

    # bottom arrow from update to predict
    ax.annotate('',
                xy=(4.1, 4.5), xycoords='data',
                xytext=(6, 4.5), textcoords='data',
                size=20,
                arrowprops=dict(arrowstyle="simple",
                                fc="0.6", ec="none",
                                patchB=pc,
                                patchA=uc,
                                connectionstyle="arc3,rad=-0.5"))

    # top arrow from predict to update
    ax.annotate('',
                xy=(6, 5.5), xycoords='data',
                xytext=(4.1, 5.5), textcoords='data',
                size=20,
                arrowprops=dict(arrowstyle="simple",
                                fc="0.6", ec="none",
                                patchB=uc,
                                patchA=pc,
                                connectionstyle="arc3,rad=-0.5"))

    # The mathtext labels below are raw strings: `\m` and `\h` are not valid
    # escape sequences in an ordinary literal and trigger a warning on
    # current Python versions. The resulting text is unchanged.
    ax.annotate(r'Measurement ($\mathbf{z_k}$)',
                xy=(6.3, 5.6), xycoords='data',
                xytext=(6, 6), textcoords='data',
                size=14,
                arrowprops=dict(arrowstyle="simple",
                                fc="0.6", ec="none"))

    # arrow from predict to state estimate
    ax.annotate('',
                xy=(4.0, 3.8), xycoords='data',
                xytext=(4.0, 4.3), textcoords='data',
                size=12,
                arrowprops=dict(arrowstyle="simple",
                                fc="0.6", ec="none"))

    # 'Initial\n' must stay an ordinary literal so that \n is a newline;
    # only the mathtext half is raw.
    ax.annotate('Initial\nConditions ' r'($\mathbf{x_0}$)',
                xy=(4.05, 5.7), xycoords='data',
                xytext=(2.5, 6.5), textcoords='data',
                size=14,
                arrowprops=dict(arrowstyle="simple",
                                fc="0.6", ec="none"))

    plt.text(4, 3.7, r'State Estimate ($\mathbf{\hat{x}_k}$)',
             ha='center', va='center', fontsize=14)
    plt.axis('equal')
    plt.show()
def show_residual_chart(show_eq=True, show_H=False):
    """Draw the chart that relates the prior, the measurement, the residual
    and the new estimate.

    Parameters
    ----------
    show_eq : bool, default True
        if true, also draw the equation for the residual and the equation
        for the new estimate
    show_H : bool, default False
        if true, the residual equation is written with the `H` term; only
        has an effect when `show_eq` is true
    """
    plt.figure(figsize=(11, 3.), facecolor='w')
    new_est = ((164.2 - 158)*.8 + 158)

    ax = plt.axes(xticks=[], yticks=[], frameon=False)

    # red arrow from (0, 158) to (1, 159)
    ax.annotate('', xy=[1, 159], xytext=[0, 158],
                arrowprops={'arrowstyle': '->', 'ec': 'r', 'lw': 3,
                            'shrinkA': 6, 'shrinkB': 5})

    # black line between (1, 164.2) and (1, 159)
    ax.annotate('', xy=[1, 159], xytext=[1, 164.2],
                arrowprops={'arrowstyle': '-', 'ec': 'k', 'lw': 3,
                            'shrinkA': 8, 'shrinkB': 8})

    # short arrow pointing at the new estimate from its label
    ax.annotate('', xy=(1., new_est), xytext=(0.9, new_est),
                arrowprops={'arrowstyle': '->', 'ec': '#004080', 'lw': 2,
                            'shrinkA': 3, 'shrinkB': 4})

    plt.scatter([0, 1], [158.0, new_est], c='k', s=128)
    plt.scatter([1], [164.2], c='b', s=128)
    plt.scatter([1], [159], c='r', s=128)

    plt.text(1.05, 158.8, r"prior $(\bar{x}_t)$",
             ha='center', va='top', fontsize=18, color='red')
    plt.text(0.5, 159.6, "prediction",
             ha='center', va='top', fontsize=18, color='red')
    plt.text(1.0, 164.4, r"measurement ($z$)",
             ha='center', va='bottom', fontsize=18, color='blue')
    plt.text(0, 157.8, r"posterior ($x_{t-1}$)",
             ha='center', va='top', fontsize=18)
    plt.text(1.02, new_est - 1.5, "residual($y$)",
             ha='left', va='center', fontsize=18)

    if show_eq:
        residual_eq = r"$y=z-H\bar x_t$" if show_H else r"$y=z-\bar x_t$"
        plt.text(1.02, new_est - 2.2, residual_eq,
                 ha='left', va='center', fontsize=18)

    plt.text(0.9, new_est, "new estimate ($x_t$)",
             ha='right', va='center', fontsize=18)
    plt.text(0.8, new_est - 0.5, "(posterior)",
             ha='right', va='center', fontsize=18)

    if show_eq:
        plt.text(0.75, new_est - 1.2, r"$\bar{x}_t + Ky$",
                 ha='right', va='center', fontsize=18)

    plt.xlabel('time')
    ax.yaxis.set_label_position("right")
    plt.ylabel('state')
    plt.xlim(-0.1, 1.5)
def show_legend():
    """Draw the legend with its center-left point anchored at (1, 0.5)."""
    anchor = (1, 0.5)
    plt.legend(loc='center left', bbox_to_anchor=anchor)
def bar_plot(pos, x=None, ylim=(0,1), title=None, c='#30a2da',
             **kwargs):
    """Draw the values in `pos` as a bar plot on the current axes.

    **Parameters**

    pos : list-like
        list of values to plot as bars

    x : list-like, optional
         If provided, specifies the x value for each value in pos. If not
         provided, the first pos element is plotted at x == 0, the second
         at 1, etc.

    ylim : (lower, upper), default = (0,1)
        specifies the lower and upper limits for the y-axis; a falsy value
        leaves the y limits untouched

    title : str, optional
        If specified, provides a title for the plot

    c : color, default='#30a2da'
        Color for the bars

    **kwargs : keywords, optional
        extra keyword arguments passed to ax.bar()
    """
    axes = plt.gca()
    positions = np.arange(len(pos)) if x is None else x
    axes.bar(positions, pos, color=c, **kwargs)

    if ylim:
        plt.ylim(ylim)

    # one tick per bar, labelled with the bar's x value
    plt.xticks(np.asarray(positions), positions)

    if title is not None:
        plt.title(title)
def plot_belief_vs_prior(belief, prior, **kwargs):
    """Plot two discrete probability distributions side by side, titled
    "belief" (left) and "prior" (right).

    Extra keyword arguments are forwarded to `bar_plot` for both panels.
    """
    panels = ((121, belief, 'belief'), (122, prior, 'prior'))
    for slot, values, name in panels:
        plt.subplot(slot)
        bar_plot(values, title=name, **kwargs)
def plot_prior_vs_posterior(prior, posterior, reverse=False, **kwargs):
    """Plot two discrete probability distributions side by side, titled
    "prior" and "posterior".

    The prior is drawn on the left unless `reverse` is true, in which case
    the posterior is drawn on the left. Extra keyword arguments are
    forwarded to `bar_plot` for both panels.
    """
    panels = [(prior, 'prior'), (posterior, 'posterior')]
    if reverse:
        panels.reverse()

    for slot, (values, name) in zip((121, 122), panels):
        plt.subplot(slot)
        bar_plot(values, title=name, **kwargs)
def set_labels(title=None, x=None, y=None):
    """Helps make code in the book shorter: optionally set the xlabel,
    ylabel and title of the current plot. Arguments left as None are skipped.
    """
    for apply, text in ((plt.xlabel, x), (plt.ylabel, y), (plt.title, title)):
        if text is not None:
            apply(text)
def set_limits(x, y):
    """Helper function to make code in the book shorter: set the limits of
    the x and y axis of the current axes.
    """
    axes = plt.gca()
    axes.set_xlim(x)
    axes.set_ylim(y)
def plot_predictions(p, rng=None, label='Prediction'):
    """Scatter-plot the values in `p` as red-edged, unfilled 'v' markers.

    `rng` supplies the x positions; when it is None the values are plotted
    at 0, 1, ..., len(p)-1. `label` is the legend label.
    """
    positions = range(len(p)) if rng is None else rng
    plt.scatter(positions, p, marker='v', s=40, edgecolor='r',
                facecolor='None', lw=2, label=label)
def plot_kf_output(xs, filter_xs, zs, title=None, aspect_equal=True):
    """Plot filter output, track and (optionally) measurements, then show.

    Column 0 of `filter_xs` is drawn with `plot_filter` and column 0 of `xs`
    with `plot_track`; both are indexed as ``[:, 0]``, so they are presumably
    2-D arrays -- TODO confirm. `zs` is drawn with `plot_measurements`
    unless it is None. If `aspect_equal` is true the axes aspect is set to
    'equal'.
    """
    filter_column = filter_xs[:, 0]
    plot_filter(filter_column)
    track_column = xs[:, 0]
    plot_track(track_column)

    if zs is not None:
        plot_measurements(zs)

    show_legend()
    set_labels(title=title, y='meters', x='time (sec)')

    if aspect_equal:
        plt.gca().set_aspect('equal')

    x_range = (-1, len(xs))
    plt.xlim(x_range)
    plt.show()
def FloatSlider(value, **kwargs):
    """Create an ipywidgets FloatSlider with continuous update turned off.

    `value` and any extra keyword arguments are passed through to
    ``ipywidgets.FloatSlider``.
    """
    slider = ipywidgets.FloatSlider(value, continuous_update=False, **kwargs)
    return slider
def IntSlider(value, **kwargs):
    """Create an ipywidgets IntSlider with continuous update turned off.

    `value` and any extra keyword arguments are passed through to
    ``ipywidgets.IntSlider``.
    """
    slider = ipywidgets.IntSlider(value, continuous_update=False, **kwargs)
    return slider
def plot_measurements(xs, ys=None, dt=None, color='k', lw=1, label='Measurements',
                      lines=False, **kwargs):
    """Helper function to give a consistent way to display measurements
    in the book.

    Parameters
    ----------
    xs : sequence
        x values, or the measurements themselves when `ys` is None
    ys : sequence, optional
        y values
    dt : number, optional
        when given and `ys` is None, `xs` is treated as the measurements and
        they are plotted at x = 0, dt, 2*dt, ...
    color : color, default 'k'
        line color (lines mode) or marker edge color (scatter mode)
    lw : number, default 1
        line width; only used when `lines` is true (markers use lw=2)
    label : str, default 'Measurements'
        legend label
    lines : bool, default False
        draw a dashed line instead of unfilled scatter markers
    **kwargs
        forwarded to ``plt.plot`` or ``plt.scatter``

    Returns
    -------
    The list returned by ``plt.plot`` in lines mode, otherwise a 1-tuple
    holding the object returned by ``plt.scatter``.
    """
    if ys is None and dt is not None:
        ys = xs
        # Exactly one x per sample. np.arange(0, len(ys)*dt, dt) can yield an
        # extra element for a fractional dt because of floating point
        # rounding, which makes x and y differ in length.
        xs = np.arange(len(ys)) * dt

    plt.autoscale(tight=False)
    if lines:
        if ys is not None:
            return plt.plot(xs, ys, color=color, lw=lw, ls='--', label=label, **kwargs)
        return plt.plot(xs, color=color, lw=lw, ls='--', label=label, **kwargs)

    # The trailing commas are intentional: the scatter result is wrapped in
    # a 1-tuple.
    if ys is not None:
        return plt.scatter(xs, ys, edgecolor=color, facecolor='none',
                           lw=2, label=label, **kwargs),
    return plt.scatter(range(len(xs)), xs, edgecolor=color, facecolor='none',
                       lw=2, label=label, **kwargs),
def plot_residual_limits(Ps, stds=1.):
    """Plot the standard deviation given by the variances in `Ps` as a
    yellow shaded region bounded by dotted lines. One std by default; use
    `stds` for a different choice (e.g. stds=3 for 3 standard deviations).
    """
    bound = np.sqrt(Ps) * stds
    lower, upper = -bound, bound

    plt.plot(lower, color='k', ls=':', lw=2)
    plt.plot(upper, color='k', ls=':', lw=2)
    plt.fill_between(range(len(bound)), lower, upper,
                     facecolor='#ffff00', alpha=0.3)
def plot_track(xs, ys=None, dt=None, label='Track', c='k', lw=2, **kwargs):
    """Plot a track as a dotted line and return the list from ``plt.plot``.

    Parameters
    ----------
    xs : sequence
        x values, or the track values themselves when `ys` is None
    ys : sequence, optional
        y values
    dt : number, optional
        when given and `ys` is None, `xs` is treated as the track values and
        they are plotted at x = 0, dt, 2*dt, ...
    label : str, default 'Track'
        legend label
    c : color, default 'k'
        line color
    lw : number, default 2
        line width
    **kwargs
        forwarded to ``plt.plot``
    """
    if ys is None and dt is not None:
        ys = xs
        # Exactly one x per sample. np.arange(0, len(ys)*dt, dt) can yield an
        # extra element for a fractional dt because of floating point
        # rounding, which makes x and y differ in length.
        xs = np.arange(len(ys)) * dt

    if ys is not None:
        return plt.plot(xs, ys, color=c, lw=lw, ls=':', label=label, **kwargs)
    return plt.plot(xs, color=c, lw=lw, ls=':', label=label, **kwargs)
def plot_filter(xs, ys=None, dt=None, c='C0', label='Filter', var=None, **kwargs):
    """Plot result of KF with color `c`, optionally displaying the variance
    of `xs`. Returns the list of lines generated by plt.plot().

    Parameters
    ----------
    xs : sequence
        x values, or the filter output itself when `ys` is None
    ys : sequence, optional
        y values
    dt : number, optional
        when given and `ys` is None, `xs` is treated as the filter output
        and plotted at x = 0, dt, 2*dt, ...; without `dt` it is plotted at
        x = 0, 1, 2, ...
    c : color, default 'C0'
        line color
    label : str, default 'Filter'
        legend label
    var : array-like, optional
        variances; when given, y +/- sqrt(var) is drawn as dotted lines with
        a yellow shaded band between them
    **kwargs
        forwarded to the ``plt.plot`` call that draws the filter output
    """
    if ys is None and dt is not None:
        ys = xs
        # Exactly one x per sample. np.arange(0, len(ys)*dt, dt) can yield an
        # extra element for a fractional dt because of floating point
        # rounding, which makes x and y differ in length.
        xs = np.arange(len(ys)) * dt
    if ys is None:
        ys = xs
        xs = range(len(ys))

    lines = plt.plot(xs, ys, color=c, label=label, **kwargs)
    if var is None:
        return lines

    std = np.sqrt(np.asarray(var))
    std_top = ys + std
    std_btm = ys - std

    plt.plot(xs, std_top, linestyle=':', color='k', lw=2)
    plt.plot(xs, std_btm, linestyle=':', color='k', lw=2)
    plt.fill_between(xs, std_btm, std_top,
                     facecolor='yellow', alpha=0.2)

    return lines
def _blob(x, y, area, colour):
    """
    Draws a square-shaped blob with the given area (< 1) centered at
    the given coordinates, filled and outlined in `colour`.
    """
    half = np.sqrt(area) / 2
    left, right = x - half, x + half
    bottom, top = y - half, y + half

    # corners listed as bottom-left, bottom-right, top-right, top-left
    xcorners = np.array([left, right, right, left])
    ycorners = np.array([bottom, bottom, top, top])
    plt.fill(xcorners, ycorners, colour, edgecolor=colour)
def hinton(W, maxweight=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.
    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever.

    Parameters
    ----------
    W : 2-D array
        weight matrix; ``W.shape`` is unpacked as (height, width). Positive
        entries are drawn as white squares, negative entries as black
        squares, and zeros are not drawn.
    maxweight : number, optional
        weight that maps to a full-size square. If not given (or falsy), the
        smallest power of two that is >= max(|W|) is used.
    """
    # Remember whether interactive mode was on so it can be switched back on
    # afterwards; previously the flag was never set and the mode stayed off.
    reenable = plt.isinteractive()
    if reenable:
        plt.ioff()

    try:
        plt.clf()
        height, width = W.shape
        if not maxweight:
            maxweight = 2**np.ceil(np.log(np.max(np.abs(W)))/np.log(2))

        # gray background covering the whole matrix
        plt.fill(np.array([0, width, width, 0]),
                 np.array([0, 0, height, height]),
                 'gray')
        plt.axis('off')
        plt.axis('equal')

        for x in range(width):
            for y in range(height):
                _x = x+1
                _y = y+1
                w = W[y, x]
                # row 0 is drawn at the top, hence `height - _y`; the square
                # area is |w|/maxweight, capped at 1
                if w > 0:
                    _blob(_x - 0.5,
                          height - _y + 0.5,
                          min(1, w/maxweight),
                          'white')
                elif w < 0:
                    _blob(_x - 0.5,
                          height - _y + 0.5,
                          min(1, -w/maxweight),
                          'black')
    finally:
        if reenable:
            plt.ion()
if __name__ == "__main__":
    # Smoke test: draw the charts defined in this module, then close them.
    # The plot_errorbar*, plot_hypothesis* and plot_estimate_chart_* helpers
    # are not defined or imported here, so calling them raised NameError
    # before any chart was drawn.
    predict_update_chart()
    show_residual_chart()
    show_residual_chart(True, True)
    plt.close('all')

    # Disabled example, kept for reference:
    # p = [0.2245871, 0.06288015, 0.06109133, 0.0581008, 0.09334062, 0.2245871,
    #      0.06288015, 0.06109133, 0.0581008, 0.09334062]*2
    # bar_plot(p)
    # plot_measurements(p)