diff --git a/clustering/affinity_propagation.ipynb b/clustering/affinity_propagation.ipynb new file mode 100644 index 0000000..7139367 --- /dev/null +++ b/clustering/affinity_propagation.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from itertools import cycle\n", + "from sklearn.cluster import AffinityPropagation\n", + "import sklearn\n", + "from sklearn.preprocessing import StandardScaler\n", + "import imageio\n", + "from io import BytesIO" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "def make_gif(figures, filename, fps=10, **kwargs):\n", + " images = []\n", + " for fig in figures:\n", + " output = BytesIO()\n", + " fig.savefig(output)\n", + " plt.close(fig) \n", + " output.seek(0)\n", + " images.append(imageio.imread(output))\n", + " imageio.mimsave(filename, images, fps=fps, **kwargs)" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "n = 50\n", + "size = (n, 2)\n", + "\n", + "x = np.random.normal(0, 1, size)\n", + "x = np.append(x, np.random.normal(5, 1, size), axis=0)\n", + "c = ['r' for _ in range(n)] + ['b' for _ in range(n)]\n", + "plt.scatter(x[:, 0], x[:, 1], c=c)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "def similarity(xi, xj):\n", + " return -((xi - xj)**2).sum()" + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def create_matrices():\n", + " S = np.zeros((x.shape[0], x.shape[0]))\n", + " R = np.array(S)\n", + " A = np.array(S)\n", + " # when looking in row i, the value means you should compare to column i - value\n", + " for i in range(x.shape[0]):\n", + " for j in range(x.shape[0]):\n", + " S[i, j] = similarity(x[i], x[j])\n", + " \n", + " return A, R, S\n", + "\n", + "A, R, S = create_matrices()\n", + "S" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "A, R, S = create_matrices()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$ r(i, k) \\leftarrow - \\max\\limits_{k' s.t. k' \\neq k}\\{ a(i, k') + s(i, k') \\}$" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "outputs": [], + "source": [ + "def update_r(lmda=0.9): \n", + " for i in range(x.shape[0]):\n", + " for k in range(x.shape[0]):\n", + " v = S[i, :] + A[i, :]\n", + " v[k] = -np.inf\n", + " v[i]= -np.inf\n", + " R[i, k] = R[i, k] * lmda + (1 - lmda) * (S[i, k] - np.max(v))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$ a(i, k) \\leftarrow \\min\\{0, r(k,k) + \\sum\\limits_{i' s.t. i' \\notin \\{i, k\\}}{\\max\\{0, r(i', k)\\}}$ " + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def update_a(lmda=0.9):\n", + " for i in range(x.shape[0]):\n", + " for k in range(x.shape[0]):\n", + " v = np.array(R[:, k])\n", + " if i != k:\n", + " v[i] = -np.inf\n", + " v[k] = - np.inf\n", + " v[v < 0] = 0\n", + "\n", + " A[i, k] = A[i, k] * lmda + (1 - lmda) * min(0, R[k, k] + v.sum())\n", + " else:\n", + " v[k] = -np.inf\n", + " v[v < 0] = 0\n", + " A[k, k] = A[k, k] * lmda + (1 - lmda) * v.sum()\n", + "\n", + "def plot_iteration(A, R):\n", + " fig = plt.figure(figsize=(12, 6))\n", + " sol = A + R\n", + " labels = []\n", + " for i in range(x.shape[0]):\n", + " labels.append(np.argmax(sol[i]))\n", + "\n", + " exemplars = np.unique(labels)\n", + " colors = dict(zip(exemplars, cycle('bgrcmyk')))\n", + " \n", + " for i in range(len(labels)):\n", + " X = x[i][0]\n", + " Y = x[i][1]\n", + " \n", + " if i in exemplars:\n", + " exemplar = i\n", + " edge = 'k'\n", + " ms = 10\n", + " else:\n", + " exemplar = labels[i]\n", + " ms = 3\n", + " edge = None\n", + " plt.plot([X, x[exemplar][0]], [Y, x[exemplar][1]], c=colors[exemplar])\n", + " plt.plot(X, Y, 'o', markersize=ms, markeredgecolor=edge, c=colors[exemplar])\n", + " \n", + "\n", + " plt.title('Number of exemplars: %s' % len(exemplars))\n", + " return fig, labels, exemplars\n", + " \n" + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "%%time\n", + "A, R, S = create_matrices()\n", + "preference = np.median(S)\n", + "\n", + "np.fill_diagonal(S, preference)\n", + "damping = 0.5\n", + "\n", + "figures = []\n", + "for i in range(20):\n", + " update_r(damping)\n", + " update_a(damping)\n", + " \n", + " if i % 5 == 0:\n", + " fig, labels, exemplars = plot_iteration(A, R)\n", + " figures.append(fig)\n", + " " + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "make_gif(figures, 'test.gif', 2)" + ] + } + ], + "metadata": { + "git": { + "suppress_outputs": true + }, + "kernelspec": { + "display_name": "Python [conda root]", + "language": "python", + "name": "conda-root-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file