affinity prop init
This commit is contained in:
229
clustering/affinity_propagation.ipynb
Normal file
229
clustering/affinity_propagation.ipynb
Normal file
@@ -0,0 +1,229 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from itertools import cycle\n",
|
||||
"from sklearn.cluster import AffinityPropagation\n",
|
||||
"import sklearn\n",
|
||||
"from sklearn.preprocessing import StandardScaler\n",
|
||||
"import imageio\n",
|
||||
"from io import BytesIO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def make_gif(figures, filename, fps=10, **kwargs):\n",
|
||||
" images = []\n",
|
||||
" for fig in figures:\n",
|
||||
" output = BytesIO()\n",
|
||||
" fig.savefig(output)\n",
|
||||
" plt.close(fig) \n",
|
||||
" output.seek(0)\n",
|
||||
" images.append(imageio.imread(output))\n",
|
||||
" imageio.mimsave(filename, images, fps=fps, **kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n = 50\n",
|
||||
"size = (n, 2)\n",
|
||||
"\n",
|
||||
"x = np.random.normal(0, 1, size)\n",
|
||||
"x = np.append(x, np.random.normal(5, 1, size), axis=0)\n",
|
||||
"c = ['r' for _ in range(n)] + ['b' for _ in range(n)]\n",
|
||||
"plt.scatter(x[:, 0], x[:, 1], c=c)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def similarity(xi, xj):\n",
|
||||
" return -((xi - xj)**2).sum()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_matrices():\n",
|
||||
" S = np.zeros((x.shape[0], x.shape[0]))\n",
|
||||
" R = np.array(S)\n",
|
||||
" A = np.array(S)\n",
|
||||
" # when looking in row i, the value means you should compare to column i - value\n",
|
||||
" for i in range(x.shape[0]):\n",
|
||||
" for j in range(x.shape[0]):\n",
|
||||
" S[i, j] = similarity(x[i], x[j])\n",
|
||||
" \n",
|
||||
" return A, R, S\n",
|
||||
"\n",
|
||||
"A, R, S = create_matrices()\n",
|
||||
"S"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"A, R, S = create_matrices()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"$ r(i, k) \\leftarrow - \\max\\limits_{k' s.t. k' \\neq k}\\{ a(i, k') + s(i, k') \\}$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def update_r(lmda=0.9): \n",
|
||||
" for i in range(x.shape[0]):\n",
|
||||
" for k in range(x.shape[0]):\n",
|
||||
" v = S[i, :] + A[i, :]\n",
|
||||
" v[k] = -np.inf\n",
|
||||
" v[i]= -np.inf\n",
|
||||
" R[i, k] = R[i, k] * lmda + (1 - lmda) * (S[i, k] - np.max(v))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"$ a(i, k) \\leftarrow \\min\\{0, r(k,k) + \\sum\\limits_{i' s.t. i' \\notin \\{i, k\\}}{\\max\\{0, r(i', k)\\}}$ "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def update_a(lmda=0.9):\n",
|
||||
" for i in range(x.shape[0]):\n",
|
||||
" for k in range(x.shape[0]):\n",
|
||||
" v = np.array(R[:, k])\n",
|
||||
" if i != k:\n",
|
||||
" v[i] = -np.inf\n",
|
||||
" v[k] = - np.inf\n",
|
||||
" v[v < 0] = 0\n",
|
||||
"\n",
|
||||
" A[i, k] = A[i, k] * lmda + (1 - lmda) * min(0, R[k, k] + v.sum())\n",
|
||||
" else:\n",
|
||||
" v[k] = -np.inf\n",
|
||||
" v[v < 0] = 0\n",
|
||||
" A[k, k] = A[k, k] * lmda + (1 - lmda) * v.sum()\n",
|
||||
"\n",
|
||||
"def plot_iteration(A, R):\n",
|
||||
" fig = plt.figure(figsize=(12, 6))\n",
|
||||
" sol = A + R\n",
|
||||
" labels = []\n",
|
||||
" for i in range(x.shape[0]):\n",
|
||||
" labels.append(np.argmax(sol[i]))\n",
|
||||
"\n",
|
||||
" exemplars = np.unique(labels)\n",
|
||||
" colors = dict(zip(exemplars, cycle('bgrcmyk')))\n",
|
||||
" \n",
|
||||
" for i in range(len(labels)):\n",
|
||||
" X = x[i][0]\n",
|
||||
" Y = x[i][1]\n",
|
||||
" \n",
|
||||
" if i in exemplars:\n",
|
||||
" exemplar = i\n",
|
||||
" edge = 'k'\n",
|
||||
" ms = 10\n",
|
||||
" else:\n",
|
||||
" exemplar = labels[i]\n",
|
||||
" ms = 3\n",
|
||||
" edge = None\n",
|
||||
" plt.plot([X, x[exemplar][0]], [Y, x[exemplar][1]], c=colors[exemplar])\n",
|
||||
" plt.plot(X, Y, 'o', markersize=ms, markeredgecolor=edge, c=colors[exemplar])\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" plt.title('Number of exemplars: %s' % len(exemplars))\n",
|
||||
" return fig, labels, exemplars\n",
|
||||
" \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"A, R, S = create_matrices()\n",
|
||||
"preference = np.median(S)\n",
|
||||
"\n",
|
||||
"np.fill_diagonal(S, preference)\n",
|
||||
"damping = 0.5\n",
|
||||
"\n",
|
||||
"figures = []\n",
|
||||
"for i in range(20):\n",
|
||||
" update_r(damping)\n",
|
||||
" update_a(damping)\n",
|
||||
" \n",
|
||||
" if i % 5 == 0:\n",
|
||||
" fig, labels, exemplars = plot_iteration(A, R)\n",
|
||||
" figures.append(fig)\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"make_gif(figures, 'test.gif', 2)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"git": {
|
||||
"suppress_outputs": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda root]",
|
||||
"language": "python",
|
||||
"name": "conda-root-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Reference in New Issue
Block a user