affinity prop init

This commit is contained in:
Ritchie
2018-05-17 15:21:30 +02:00
parent 520d2c62de
commit 681ee89634

View File

@@ -0,0 +1,229 @@
{
"cells": [
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from itertools import cycle\n",
"from sklearn.cluster import AffinityPropagation\n",
"import sklearn\n",
"from sklearn.preprocessing import StandardScaler\n",
"import imageio\n",
"from io import BytesIO"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def make_gif(figures, filename, fps=10, **kwargs):\n",
" images = []\n",
" for fig in figures:\n",
" output = BytesIO()\n",
" fig.savefig(output)\n",
" plt.close(fig) \n",
" output.seek(0)\n",
" images.append(imageio.imread(output))\n",
" imageio.mimsave(filename, images, fps=fps, **kwargs)"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"n = 50\n",
"size = (n, 2)\n",
"\n",
"x = np.random.normal(0, 1, size)\n",
"x = np.append(x, np.random.normal(5, 1, size), axis=0)\n",
"c = ['r' for _ in range(n)] + ['b' for _ in range(n)]\n",
"plt.scatter(x[:, 0], x[:, 1], c=c)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def similarity(xi, xj):\n",
" return -((xi - xj)**2).sum()"
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def create_matrices():\n",
" S = np.zeros((x.shape[0], x.shape[0]))\n",
" R = np.array(S)\n",
" A = np.array(S)\n",
" # when looking in row i, the value means you should compare to column i - value\n",
" for i in range(x.shape[0]):\n",
" for j in range(x.shape[0]):\n",
" S[i, j] = similarity(x[i], x[j])\n",
" \n",
" return A, R, S\n",
"\n",
"A, R, S = create_matrices()\n",
"S"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"A, R, S = create_matrices()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$ r(i, k) \\leftarrow - \\max\\limits_{k' s.t. k' \\neq k}\\{ a(i, k') + s(i, k') \\}$"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def update_r(lmda=0.9): \n",
" for i in range(x.shape[0]):\n",
" for k in range(x.shape[0]):\n",
" v = S[i, :] + A[i, :]\n",
" v[k] = -np.inf\n",
" v[i]= -np.inf\n",
" R[i, k] = R[i, k] * lmda + (1 - lmda) * (S[i, k] - np.max(v))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$ a(i, k) \\leftarrow \\min\\{0, r(k,k) + \\sum\\limits_{i' s.t. i' \\notin \\{i, k\\}}{\\max\\{0, r(i', k)\\}}$ "
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"def update_a(lmda=0.9):\n",
" for i in range(x.shape[0]):\n",
" for k in range(x.shape[0]):\n",
" v = np.array(R[:, k])\n",
" if i != k:\n",
" v[i] = -np.inf\n",
" v[k] = - np.inf\n",
" v[v < 0] = 0\n",
"\n",
" A[i, k] = A[i, k] * lmda + (1 - lmda) * min(0, R[k, k] + v.sum())\n",
" else:\n",
" v[k] = -np.inf\n",
" v[v < 0] = 0\n",
" A[k, k] = A[k, k] * lmda + (1 - lmda) * v.sum()\n",
"\n",
"def plot_iteration(A, R):\n",
" fig = plt.figure(figsize=(12, 6))\n",
" sol = A + R\n",
" labels = []\n",
" for i in range(x.shape[0]):\n",
" labels.append(np.argmax(sol[i]))\n",
"\n",
" exemplars = np.unique(labels)\n",
" colors = dict(zip(exemplars, cycle('bgrcmyk')))\n",
" \n",
" for i in range(len(labels)):\n",
" X = x[i][0]\n",
" Y = x[i][1]\n",
" \n",
" if i in exemplars:\n",
" exemplar = i\n",
" edge = 'k'\n",
" ms = 10\n",
" else:\n",
" exemplar = labels[i]\n",
" ms = 3\n",
" edge = None\n",
" plt.plot([X, x[exemplar][0]], [Y, x[exemplar][1]], c=colors[exemplar])\n",
" plt.plot(X, Y, 'o', markersize=ms, markeredgecolor=edge, c=colors[exemplar])\n",
" \n",
"\n",
" plt.title('Number of exemplars: %s' % len(exemplars))\n",
" return fig, labels, exemplars\n",
" \n"
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"A, R, S = create_matrices()\n",
"preference = np.median(S)\n",
"\n",
"np.fill_diagonal(S, preference)\n",
"damping = 0.5\n",
"\n",
"figures = []\n",
"for i in range(20):\n",
" update_r(damping)\n",
" update_a(damping)\n",
" \n",
" if i % 5 == 0:\n",
" fig, labels, exemplars = plot_iteration(A, R)\n",
" figures.append(fig)\n",
" "
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"make_gif(figures, 'test.gif', 2)"
]
}
],
"metadata": {
"git": {
"suppress_outputs": true
},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}