From 8f581f8b4b1e9aefec94ed53d9137bf052f0a521 Mon Sep 17 00:00:00 2001 From: ritchie46 Date: Fri, 24 May 2019 13:31:53 +0200 Subject: [PATCH] expectation_maximization --- clustering/expectation_maximization.ipynb | 215 ++++++++++++++++++++++ readme.md | 1 + 2 files changed, 216 insertions(+) create mode 100644 clustering/expectation_maximization.ipynb diff --git a/clustering/expectation_maximization.ipynb b/clustering/expectation_maximization.ipynb new file mode 100644 index 0000000..5e084fc --- /dev/null +++ b/clustering/expectation_maximization.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import scipy.stats as stats\n", + "# http://bjlkeng.github.io/posts/the-expectation-maximization-algorithm/" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(36817)\n", + "p_n = stats.binom(40, 0.5)\n", + "k1 = np.random.multivariate_normal([0, 0], np.array([[1, -0.5], [-0.5, 1]]), p_n.rvs())\n", + "\n", + "k2 = np.random.multivariate_normal([9, -1], np.array([[1, 0.1], [0.1, 1]]), p_n.rvs())\n", + "\n", + "k3 = np.random.multivariate_normal([-8, 0.3], np.array([[1, 0.9], [0.9, 1]]), p_n.rvs())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAF7xJREFUeJzt3X2MXFd5x/Hfk7WDNwt4iWLwehORuE2thmSLrVWqJkWCBq1DlxBjQUSQWioqpVWJ2KLKIiYQVgFkt1YbbVpolRaE/+CllpuYl4E6b5VoiAqM7dROCC7RFhpv1s0iZFNt1om9fvrHzDi7s/fO6525d879fiRrvXdmZ05Gq1+On3vOc8zdBQAIx0VpDwAAkCyCHQACQ7ADQGAIdgAIDMEOAIEh2AEgMAQ7AASGYAeAwBDsABCYVWm86WWXXeZXXnllGm8NAD3r0KFDv3D3dfWel0qwX3nllSoWi2m8NQD0LDP7eSPPoxQDAIEh2AEgMAQ7AASGYAeAwBDsABAYgh0AApPKckegEw4cmdGeg8f1wqkFbRjs146tm7Rt83DawwK6jmBHEA4cmdHOB49p4eyiJGnm1IJ2PnhMkgh35A6lGARhz8HjF0K9YuHsovYcPJ7SiID0EOxdUJguaGz/mEb2jmhs/5gK04W0hxScF04tNHUdCBnB3mGF6YImn5zU7PysXK7Z+VlNPjlJuCdsw2B/U9eBkBHsHTZ1eEpnFs8su3Zm8YymDk+lNKIw7di6Sf2r+5Zd61/dpx1bN6U0IiA93DztsJPzJ5u6jtZUbpCyKgYg2Dtu/cB6zc7PRl5HsrZtHibIAVGK6biJLRNa07dm2bU1fWs0sWUipREBCB0z9g4b3zguqVRrPzl/UusH1mtiy8SF6wCQNIK9C8Y3jhPkALqGUgwABIZgB4DAEOwAEBiCHQAC03awm9kVZvZvZvasmT1jZqzjA4AUJbEq5pykv3D3w2b2OkmHzOwRd/9xAq8NAGhS2zN2d59198Plv/+fpGclsf0PAFKSaI3dzK6UtFnSDyIeu8PMimZWnJubS/JtAQBLJBbsZvZaSf8i6c/d/VfVj7v7A+4+6u6j69atS+ptAQBVEgl2M1utUqh/xd0fTOI1AQCtSWJVjEn6oqRn3f1v2h8SAKAdSayKuVHSH0g6ZmZPla99wt2/k8BrB6UwXdDuH+7WqZdPSZLWXrxWO397J31kACSq7WB39yckWQJjCVphuqBPff9TOnv+7IVrp185rU8+8UlJItwBJIadp11QmC7oE098YlmoV5zzcxyTByBRBHuHVQ6zPu/nY5/DMXkAkkSwd1jUYdbVOCYPQJII9g6rNxtfZas4Jg9Aogj2Dqs1G1978Vp99nc/y41TAIki2Dss7jDr3W/brSduf4JQB5A4zjztMA6zBtBtBHsXcJg1gG6iFNMhhemCxvaPaWTviMb2j6kwXUh7SAByghl7iwrThdjySmXtemWZ4+z8rCafnJTEDlMAnUewt6BecEetXT+zeEZTh6cI9hYdODKjPQeP64VTC9ow2K8dWzdp22bOcwGiUIppQa3gluLXrrPDtDUHjsxo54PHNHNqQS5p5tSCdj54TAeOzKQ9NCCTmLG3oF5wrx9Yr9n52RWPs8O0MdWz8/mXz2nh7OKy5yycXdSeg8eZtQMRmLG3IC6gK9fj1q6zw7S+qNn5qYWVzdMk6YVTC90dHNAjCPYW1Avu8Y3jmrxhUkMDQzKZhgaGNHnDJPX1Buw5eHzF7DzOhsH+Do8G6E2UYlrQyKYj1q63ptFZeP/qPu3YuqnDowF6E8HeIoK7MzYM9msmItzfcMlqXXLxKlbFAA0g2JEpO7Zu0s4Hjy0rx/Sv7tOnb3kLQQ40iGBHplTCmzXrQOsIdmTOts3DBDnQBlbFAEBgmLEDLaDFAbKMYAeaVNlEVbnBW2lxIIlwRyYQ7Oi6Xp/tRm2iosUBsoRgR1eFMNuN20RFiwNkBTdP0VW1Zru9Iq6VAS0OkBUEO7oqhNnujq2b1L+6b9k1WhwgSwh2dFUIs91tm4e1a/t1Gh7sl0kaHuzXru3X9UwpCeFLpMZuZl+S9G5JL7r7tUm8Zi+rdWxe3sW1DOi12S6bqJBlSc3Yvyzp5oReq6dVjs2bnZ+Vyy8cm8dh1iXMdoHOM3dP5oXMrpT07UZm7KOjo14sFhN536wZ2z8WeXrS0MCQHn7fwymMCEAozOyQu4/Wex419oRx3imAtHUt2M3sDjMrmllxbm6uW2/bdfWOzQOATutasLv7A+4+6u6j69at69bbdh3nnQJIGztPE9bIsXkA0ElJLXf8mqS3S7rMzE5I+rS7fzGJ1+5FHJvXG3q9Zw0QJ5Fgd/fbk3gdoFtC6FkDxGFVDHIphJ41QByCHbkUQs8aIA7BjlwKoWcNEIdgRy7RoREhY7kjcqlyg5RVMQgRwV4DXRrDRodGhIpgj1CYLmjXD3bp9CunL1yrdGmURLgDyDRq7FUqbXeXhnrFmcUzmjo8lcKoAKBxBHuVqcNTOrN4JvZxujQCyDqCvUq94KZLI4CsI9ir1Aru1Ret1ktnX9LI3hGN7R/jVCQAmUSwV4lquytJl6y6RO6u06+c5si7Nh04MqMbdz+uq+4q6Mbdj+vAkZm0hwQEhWCvMr5xXJM3TGpoYEgm09DAkHa/bbfWvmatzvm5Zc/lZmrzKs23Zk4tyPVq8y3CHUgOyx0jRLXd3fnvOyOfy83U5tRqvsWaciAZBHuDXn/x6yOXQHIztTlJN9+ipzqwEqWYBhSmC3rp3Esrrq+yVRx516Qkm29R1gGiEewNmDo8pbPnz664/tqLX8su1CYl2XyLnupANEoxDYiro59+eWVpBrUl2XyLnupANIK9AesH1mt2fjbyOpqXRPOtA0dmdJGZFt1XPEZPdeQdpZgGRK1tX9O3hvp6Siq19ahQp6c6wIy9IZU6Oi18syGqti5JfWbatf06VsUg9wh2NdZ3PWptO9IRV0M/706oI1lH90mP3SudPiGtvVy66R5p5La0R1VX7oO90qa30tGRvuvZt2GwXzMR4U5tHYk6uk/61kels+XftdPPl76XMh/uua+xR7XppVVAtnFeKbrisXtfDfWKswul6xmX+xl73FJGWgVkF+eVoitOn2jueobkPthZytibOK8UHbf28lL5Jep6xuW+FMNSRgCRbrpHWl1132Z1f+l6xuV+xs5SRgCRKjdIe3BVjHnEJo9OGx0d9WKx2PX3RT7V6gBJd0gkpgtLI83skLuP1nteIjN2M7tZ0pSkPkn/5O67k3hdoF2VXaqVDU2VDpAVcY8R7mhKxpZGtl1jN7M+SZ+X9C5J10i63cyuafd1gSTU6gBJd0gkJmNLI5OYsV8v6Tl3n5YkM/u6pFsl/TiB1wba0koHSLpD5lwrJZWMLY1MYlXMsKSla4JOlK8tY2Z3mFnRzIpzc3MJvC1QX62DPZI89AOBqJRUTj8vyV8tqRzdV/vn4pZAprQ0Molgt4hrK+7IuvsD7j7q7qPr1q1L4G2B+mrtUmUHK1ZotaSSsaWRSZRiTki6Ysn3l0t6IYHXBVpSWekyc2pBfeWe7ZWvwxErX1gVgwtaLalkbGlkEsH+I0lXm9lVkmYkfUDSBxN4XaBp1atgKj3bF90vzMYJbsRqZ7fpyG2ZWePedinG3c9JulPSQUnPStrn7s+0+7ppK0wXNLZ/TCN7RzS2f0yF6ULaQ0ID4nq1SytXvHAYNlbIWEmlVYm0FHD377j7b7j7r7n755J4zSQ1G9KVVr6z87Ny+YVWvoR79tVb0bL0cZY7YoWR26Rb7pfWXiHJSl9vuT8zM/FGBd9SoJV+67Va+dJqINvierUvfbyi1lJIdqTmWIZKKq0KvglYK/3WaeXbu6JWulRUr3iJW9Y4eMlqSjToacEHeyshHdeyl1a+2bdt87B2bb9Ow+XQ7rPSatzhwf4V56HGLXd0FyUa9LTgSzGt9Fuf2DKxrHwj0cq3lzTaqz3uwI6P/fNTkc9nRyp6RfDB3kpI08o3P6L+J1BZA1+NHanoFcEHe6shPb5xnCDPqR1bNy1bCy+xIxW9JfhglwhpNIczVdHrchHsQLM4UxW9jGAHlmD9eo/owmlFvYxgB8pqnbZEuGdIxk4ryqLg17EDjaLFQI/I2GlFWUSwA2WtnLaEFGTstKIsItiBMk5U6hEZO60oiwh2oIwTlXpEIK11Oym3N08L0wV2lmIZ1q/3iIydVpRF5r7ieNKOGx0d9WKx2PX3rahu5SuV2gxM3jBJuAPILDM75O6j9Z6Xy1JMK618AQTo6D7pvmulycHS16P70h5RInJTillaenFF/yuFfutAjgS8Hj4XM/bqo+7i0G8dyJGA18PnItijSi/V6LcO5EzA6+FzEey1Siwm09DAEDdOgbwJeD18LmrscacoDQ0M6eH3PZzCiAB0TVzDsJvuWV5jl4JZD5+LGfvElgmt6Vuz7BqlFyAHKjdITz8vyV+9QXp0Xyncb7lfWnuFJCt9veX+nr9xKuVkxs5Rd+gW2v5mTK0bpCO3vfonMLkIdqmxU5TYjYp20PY3gwK+QVpLLkoxjaheEjk7P6vJJydVmC6kPTT0CNr+ZlDAN0hrIdjL2I2KRh04MqMbdz+uq+4q6Mbdj+vAkRlJtP3NpJw2DMtNKaaeuCWR7EbFUrXKLRsG+zUTEeK0/U1RIw3DAjxmr61gN7P3S5qU9JuSrnf39Dp7tSluSSS7UbFUrXLLjq2bloW+RNvfTKh1gzTQtgLtlmKelrRd0vcSGEtHFKYLGts/ppG9IxrbPxZbM2dJJBpRq9yybfOwdm2/TsOD/TJJw4P92rX9Om6cZlmgbQXamrG7+7OSZGbJjCZh1e15KzdEJa1Y7cKSyDC1uvww7ufqlVu2bR4myHtJoKtmgq6x17ohGhXYjSyJRO9odflhrZ+j3BKYtZeXNy9FXO9hdUsxZvaomT0d8efWZt7IzO4ws6KZFefm5lofcRO4IZpvrS4/rPVzlFsyoJke6vWeG+iqmbozdnd/ZxJv5O4PSHpAKp2glMRr1sMN0XxrdvlhpfwSVWpZ+nOUW1LUzM3ORp4b6DF7Qa9j54ZovsUtM4y6Xim/xIV6rddDFzVzs7PR547cJn3saWnyVOlrj4e61Gawm9l7zeyEpN+RVDCzg8kMKxnjG8c1ecOkhgaGaM+bQzu2blL/6r5l1+Lq4VHll0Z+Dl3WzM3OQG+MNqLdVTEPSXooobF0BDdE86tSLmlkVUyt3aHDNPPKjmZudrZ7YzTJjUtd3gQV9KoYoNF6eNwyxuHBfn3/rt/rxNDQimZ6qLfTbz3JjUspbIIKusYONKqZsg1S1EwP9Xb6rSe5cSmFTVDM2AE1V7ZByprpod5qv/Uk6/Mp1PoJdqCMZYy4IMmNSylsgqIUAwDVG5muHktu41IKm6AIdgD5FnUu6n9+VfqtDyZzHmoKZ6tSigGQb3E3N3/6cGnDUhK6fLYqM3YA+RbgRiaCHUC+BXguKsEOIN9aubnZTIfJFFBjB5BvzXZ47IHj9Ah2AL0rqR4szdzcrLWTlGAHgDakNXPugZut1NgB9KYke7A0UzOPu6lqF2Wm5k6wA+hNSc2cozYofeuj8eEcdbNVknyxsZ/vAoIdQG9KapliszP/6p2k1rfyOR3u3lgPNXaghso5qHR8zKB2+q0v1crMf+nN1snB5n++w5ixAzGWnoPqkmZOLWjng8d04MhM2kODlFwPlnZn/hnc4ESwAzGizkFdOLuoPQePpzQiRHplXhdq29/9ePO17Xa7L6bQvbEegh2IEXcOaq3zUdFFR/dJ3/iItPDLV68t/FI68GfR4R638qXdmX8K3RvrocYOxIg7B3XDYMSKCHTfY/dKi6+svH7+7MrNQvXWvLfbfbHL3RvrYcYOxOAc1IyrdXOy+rEUzh1NEzN2IAbnoGZc3JFzlceW6oHdokki2IEaOAc1w266p1Rjry7HXLR65Y3LFM4dTROlGAC9aeQ26dbPS/2Xvnqt/1Jp2xdW1rtrrVzJeAveVjBjB9C7Gr1pGdeaV8p8C95WEOwA8iHqfwL3XZv5Fryt6LlSTGG6oLH9YxrZO6Kx/WMqTBfSHhKAXhXoTdWeCvbCdEGTT05qdn5WLtfs/Kwmn5wk3AG0JoPtAJLQVrCb2R4z+4mZHTWzh8wsphtOMqYOT+nM4pll184sntHU4alOvi2AUGWwHUAS2p2xPyLpWncfkfRfkna2P6R4J+dPNnUdAGrKYDuAJLR189TdH17y7X9Iel97w6lt/cB6zc7PRl4HgKZUn5e6/YGeD/SKJGvsH5b03bgHzewOMyuaWXFubq6lN5jYMqE1fWuWXVvTt0YTWyZaej0AOdXsqUk9xty99hPMHpUUNSW+292/UX7O3ZJGJW33ei8oaXR01IvFYgvDLd1AnTo8pZPzJ7V+YL0mtkxofON4S68FIKfuuzZmJ+oV0see7v54GmRmh9x9tN7z6pZi3P2ddd7oQ5LeLemmRkK9XeMbxwlyAO0JdJljRburYm6W9HFJ73H3l5IZEgB0WKDLHCvarbH/naTXSXrEzJ4ys39IYEwA0FmBLnOsaHdVzK8nNRAA6Jq43jGBrIqhVwyAfMrYqUdJ6qmWAgCA+gh2AAgMwQ4AgSHYAeRLgCcmVePmKYAwVfeCCfjEpGoEO4DwVHrBVAf4qv4gT0yqRrADCM9j90YHePW1ikBaCVRQYwcQnmaDOpBWAhUEO4DwxAV1/6VBtxKoINgBhCeuF8y7/jLIE5OqUWMHEJ56vWACC/JqBDuAMAXcC6YeSjEAEBiCHQACQ7ADQGAIdgAIDMEOAIEh2AEgMAQ7AASGYAeAwBDsAMKSg4M06mHnKYBwxPVhl3K1C5UZO4BwxPVhf+zedMaTEoIdQDji+rAHdpBGPQQ7gHDE9WEP7CCNegh2AOGI68Me2EEa9RDsAMIxclsuDtKop61VMWb2GUm3Sjov6UVJf+TuLyQxMABoSY77sFe0O2Pf4+4j7v5WSd+WlK9/7wBABrUV7O7+qyXfDkjy9oYDAGhX2xuUzOxzkv5Q0mlJ72h7RACAttSdsZvZo2b2dMSfWyXJ3e929yskfUXSnTVe5w4zK5pZcW5uLrn/AgDAMuaeTPXEzN4sqeDu19Z77ujoqBeLxUTeFwDywswOuftovee1VWM3s6uXfPseST9p5/UAAO1rt8a+28w2qbTc8eeS/rT9IQEA2pFYKaapNzWbU+l/BO24TNIvEhhOqPh86uMzqo3Pp75uf0Zvdvd19Z6USrAnwcyKjdSa8orPpz4+o9r4fOrL6mdESwEACAzBDgCB6eVgfyDtAWQcn099fEa18fnUl8nPqGdr7ACAaL08YwcAROipYDez95vZM2Z23sxGqx7baWbPmdlxM9ua1hizxMwmzWzGzJ4q//n9tMeUBWZ2c/n35Dkzuyvt8WSRmf3MzI6Vf29yv03czL5kZi+a2dNLrl1qZo+Y2U/LX9+Q5hiX6qlgl/S0pO2Svrf0opldI+kDkt4i6WZJXzCzvu4PL5Puc/e3lv98J+3BpK38e/F5Se+SdI2k28u/P1jpHeXfm8wt50vBl1XKlqXukvSYu18t6bHy95nQU8Hu7s+6+/GIh26V9HV3f9nd/1vSc5Ku7+7o0COul/Scu0+7+yuSvq7S7w8Qy92/J+mXVZdvlbS3/Pe9krZ1dVA19FSw1zAs6fkl358oX4N0p5kdLf9TMjP/VEwRvyuNcUkPm9khM7sj7cFk1JvcfVaSyl/fmPJ4Lmi7H3vSzOxRSesjHrrb3b8R92MR13Kx3KfW5yXp7yV9RqXP4jOS/lrSh7s3ukzK7e9Kk2509xfM7I2SHjGzn5RnregBmQt2d39nCz92QtIVS76/XFIuzl5t9PMys39U6fjCvMvt70ozKmcXu/uLZvaQSiUsgn25/zWzIXefNbMhlc59zoRQSjHflPQBM3uNmV0l6WpJP0x5TKkr/7JVvFelm8959yNJV5vZVWZ2sUo33b+Z8pgyxcwGzOx1lb9LGhO/O1G+KelD5b9/SFJcRaHrMjdjr8XM3ivpbyWtk1Qws6fcfau7P2Nm+yT9WNI5SR9x98U0x5oRf2Vmb1Wp1PAzSX+S7nDS5+7nzOxOSQcl9Un6krs/k/KwsuZNkh4yM6mUEV91939Nd0jpMrOvSXq7pMvM7ISkT0vaLWmfmf2xpP+R9P70RrgcO08BIDChlGIAAGUEOwAEhmAHgMAQ7AAQGIIdAAJDsANAYAh2AAgMwQ4Agfl/nEMGyDyUZFgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(k1[:, 0], k1[:, 1])\n", + "plt.scatter(k2[:, 0], k2[:, 1])\n", + "plt.scatter(k3[:, 0], k3[:, 1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Gaussian mixture\n", + "Assume there is $\\pi$ probability of a data point coming from one of $K$ clusters.\n", + "\n", + "$$ z_i \\sim Categorical(\\pi) $$\n", + "\n", + "$$ x_i|z_i \\sim N(\\mu_k, \\sigma_k^2) $$\n", + "$$ p(x_i|\\theta) = \\sum_{k=1}^Kp(z_i=k) \\cdot p(x_i|z_i=k, \\mu_k, \\sigma_k^2)$$\n", + "\n", + "## maximization step\n", + "$$ \\pi_k = \\frac{1}{N}\\sum_i p(z_i = k| x_i, \\theta)$$\n", + "$$ \\mu_k = \\frac{\\sum_i p(z_i = k| x_i, \\theta) x_i}{\\sum_i p(z_i = k| x_i, \\theta)}$$\n", + "$$ cov_k = \\frac{\\sum_i p(z_i = k| x_i, \\theta) \\cdot (x_i - \\mu_k)(y_i - \\mu_k)}{\\sum_i p(z_i = k| x_i, \\theta)}$$" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.vstack([k1, k2, k3])" + ] + }, + { + "cell_type": "code", + "execution_count": 493, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.33846154 0.36923077 0.29230769]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 493, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAF7xJREFUeJzt3X2MXFd5x/Hfk7WDNwt4iWLwehORuE2thmSLrVWqJkWCBq1DlxBjQUSQWioqpVWJ2KLKIiYQVgFkt1YbbVpolRaE/+CllpuYl4E6b5VoiAqM7dROCC7RFhpv1s0iZFNt1om9fvrHzDi7s/fO6525d879fiRrvXdmZ05Gq1+On3vOc8zdBQAIx0VpDwAAkCyCHQACQ7ADQGAIdgAIDMEOAIEh2AEgMAQ7AASGYAeAwBDsABCYVWm86WWXXeZXXnllGm8NAD3r0KFDv3D3dfWel0qwX3nllSoWi2m8NQD0LDP7eSPPoxQDAIEh2AEgMAQ7AASGYAeAwBDsABAYgh0AApPKckegEw4cmdGeg8f1wqkFbRjs146tm7Rt83DawwK6jmBHEA4cmdHOB49p4eyiJGnm1IJ2PnhMkgh35A6lGARhz8HjF0K9YuHsovYcPJ7SiID0EOxdUJguaGz/mEb2jmhs/5gK04W0hxScF04tNHUdCBnB3mGF6YImn5zU7PysXK7Z+VlNPjlJuCdsw2B/U9eBkBHsHTZ1eEpnFs8su3Zm8YymDk+lNKIw7di6Sf2r+5Zd61/dpx1bN6U0IiA93DztsJPzJ5u6jtZUbpCyKgYg2Dtu/cB6zc7PRl5HsrZtHibIAVGK6biJLRNa07dm2bU1fWs0sWUipREBCB0z9g4b3zguqVRrPzl/UusH1mtiy8SF6wCQNIK9C8Y3jhPkALqGUgwABIZgB4DAEOwAEBiCHQAC03awm9kVZvZvZvasmT1jZqzjA4AUJbEq5pykv3D3w2b2OkmHzOwRd/9xAq8NAGhS2zN2d59198Plv/+fpGclsf0PAFKSaI3dzK6UtFnSDyIeu8PMimZWnJubS/JtAQBLJBbsZvZaSf8i6c/d/VfVj7v7A+4+6u6j69atS+ptAQBVEgl2M1utUqh/xd0fTOI1AQCtSWJVjEn6oqRn3f1v2h8SAKAdSayKuVHSH0g6ZmZPla99wt2/k8BrB6UwXdDuH+7WqZdPSZLWXrxWO397J31kACSq7WB39yckWQJjCVphuqBPff9TOnv+7IVrp185rU8+8UlJItwBJIadp11QmC7oE098YlmoV5zzcxyTByBRBHuHVQ6zPu/nY5/DMXkAkkSwd1jUYdbVOCYPQJII9g6rNxtfZas4Jg9Aogj2Dqs1G1978Vp99nc/y41TAIki2Dss7jDr3W/brSduf4JQB5A4zjztMA6zBtBtBHsXcJg1gG6iFNMhhemCxvaPaWTviMb2j6kwXUh7SAByghl7iwrThdjySmXtemWZ4+z8rCafnJTEDlMAnUewt6BecEetXT+zeEZTh6cI9hYdODKjPQeP64VTC9ow2K8dWzdp22bOcwGiUIppQa3gluLXrrPDtDUHjsxo54PHNHNqQS5p5tSCdj54TAeOzKQ9NCCTmLG3oF5wrx9Yr9n52RWPs8O0MdWz8/mXz2nh7OKy5yycXdSeg8eZtQMRmLG3IC6gK9fj1q6zw7S+qNn5qYWVzdMk6YVTC90dHNAjCPYW1Avu8Y3jmrxhUkMDQzKZhgaGNHnDJPX1Buw5eHzF7DzOhsH+Do8G6E2UYlrQyKYj1q63ptFZeP/qPu3YuqnDowF6E8HeIoK7MzYM9msmItzfcMlqXXLxKlbFAA0g2JEpO7Zu0s4Hjy0rx/Sv7tOnb3kLQQ40iGBHplTCmzXrQOsIdmTOts3DBDnQBlbFAEBgmLEDLaDFAbKMYAeaVNlEVbnBW2lxIIlwRyYQ7Oi6Xp/tRm2iosUBsoRgR1eFMNuN20RFiwNkBTdP0VW1Zru9Iq6VAS0OkBUEO7oqhNnujq2b1L+6b9k1WhwgSwh2dFUIs91tm4e1a/t1Gh7sl0kaHuzXru3X9UwpCeFLpMZuZl+S9G5JL7r7tUm8Zi+rdWxe3sW1DOi12S6bqJBlSc3Yvyzp5oReq6dVjs2bnZ+Vyy8cm8dh1iXMdoHOM3dP5oXMrpT07UZm7KOjo14sFhN536wZ2z8WeXrS0MCQHn7fwymMCEAozOyQu4/Wex419oRx3imAtHUt2M3sDjMrmllxbm6uW2/bdfWOzQOATutasLv7A+4+6u6j69at69bbdh3nnQJIGztPE9bIsXkA0ElJLXf8mqS3S7rMzE5I+rS7fzGJ1+5FHJvXG3q9Zw0QJ5Fgd/fbk3gdoFtC6FkDxGFVDHIphJ41QByCHbkUQs8aIA7BjlwKoWcNEIdgRy7RoREhY7kjcqlyg5RVMQgRwV4DXRrDRodGhIpgj1CYLmjXD3bp9CunL1yrdGmURLgDyDRq7FUqbXeXhnrFmcUzmjo8lcKoAKBxBHuVqcNTOrN4JvZxujQCyDqCvUq94KZLI4CsI9ir1Aru1Ret1ktnX9LI3hGN7R/jVCQAmUSwV4lquytJl6y6RO6u06+c5si7Nh04MqMbdz+uq+4q6Mbdj+vAkZm0hwQEhWCvMr5xXJM3TGpoYEgm09DAkHa/bbfWvmatzvm5Zc/lZmrzKs23Zk4tyPVq8y3CHUgOyx0jRLXd3fnvOyOfy83U5tRqvsWaciAZBHuDXn/x6yOXQHIztTlJN9+ipzqwEqWYBhSmC3rp3Esrrq+yVRx516Qkm29R1gGiEewNmDo8pbPnz664/tqLX8su1CYl2XyLnupANEoxDYiro59+eWVpBrUl2XyLnupANIK9AesH1mt2fjbyOpqXRPOtA0dmdJGZFt1XPEZPdeQdpZgGRK1tX9O3hvp6Siq19ahQp6c6wIy9IZU6Oi18syGqti5JfWbatf06VsUg9wh2NdZ3PWptO9IRV0M/706oI1lH90mP3SudPiGtvVy66R5p5La0R1VX7oO90qa30tGRvuvZt2GwXzMR4U5tHYk6uk/61kels+XftdPPl76XMh/uua+xR7XppVVAtnFeKbrisXtfDfWKswul6xmX+xl73FJGWgVkF+eVoitOn2jueobkPthZytibOK8UHbf28lL5Jep6xuW+FMNSRgCRbrpHWl1132Z1f+l6xuV+xs5SRgCRKjdIe3BVjHnEJo9OGx0d9WKx2PX3RT7V6gBJd0gkpgtLI83skLuP1nteIjN2M7tZ0pSkPkn/5O67k3hdoF2VXaqVDU2VDpAVcY8R7mhKxpZGtl1jN7M+SZ+X9C5J10i63cyuafd1gSTU6gBJd0gkJmNLI5OYsV8v6Tl3n5YkM/u6pFsl/TiB1wba0koHSLpD5lwrJZWMLY1MYlXMsKSla4JOlK8tY2Z3mFnRzIpzc3MJvC1QX62DPZI89AOBqJRUTj8vyV8tqRzdV/vn4pZAprQ0Molgt4hrK+7IuvsD7j7q7qPr1q1L4G2B+mrtUmUHK1ZotaSSsaWRSZRiTki6Ysn3l0t6IYHXBVpSWekyc2pBfeWe7ZWvwxErX1gVgwtaLalkbGlkEsH+I0lXm9lVkmYkfUDSBxN4XaBp1atgKj3bF90vzMYJbsRqZ7fpyG2ZWePedinG3c9JulPSQUnPStrn7s+0+7ppK0wXNLZ/TCN7RzS2f0yF6ULaQ0ID4nq1SytXvHAYNlbIWEmlVYm0FHD377j7b7j7r7n755J4zSQ1G9KVVr6z87Ny+YVWvoR79tVb0bL0cZY7YoWR26Rb7pfWXiHJSl9vuT8zM/FGBd9SoJV+67Va+dJqINvierUvfbyi1lJIdqTmWIZKKq0KvglYK/3WaeXbu6JWulRUr3iJW9Y4eMlqSjToacEHeyshHdeyl1a+2bdt87B2bb9Ow+XQ7rPSatzhwf4V56HGLXd0FyUa9LTgSzGt9Fuf2DKxrHwj0cq3lzTaqz3uwI6P/fNTkc9nRyp6RfDB3kpI08o3P6L+J1BZA1+NHanoFcEHe6shPb5xnCDPqR1bNy1bCy+xIxW9JfhglwhpNIczVdHrchHsQLM4UxW9jGAHlmD9eo/owmlFvYxgB8pqnbZEuGdIxk4ryqLg17EDjaLFQI/I2GlFWUSwA2WtnLaEFGTstKIsItiBMk5U6hEZO60oiwh2oIwTlXpEIK11Oym3N08L0wV2lmIZ1q/3iIydVpRF5r7ieNKOGx0d9WKx2PX3rahu5SuV2gxM3jBJuAPILDM75O6j9Z6Xy1JMK618AQTo6D7pvmulycHS16P70h5RInJTillaenFF/yuFfutAjgS8Hj4XM/bqo+7i0G8dyJGA18PnItijSi/V6LcO5EzA6+FzEey1Siwm09DAEDdOgbwJeD18LmrscacoDQ0M6eH3PZzCiAB0TVzDsJvuWV5jl4JZD5+LGfvElgmt6Vuz7BqlFyAHKjdITz8vyV+9QXp0Xyncb7lfWnuFJCt9veX+nr9xKuVkxs5Rd+gW2v5mTK0bpCO3vfonMLkIdqmxU5TYjYp20PY3gwK+QVpLLkoxjaheEjk7P6vJJydVmC6kPTT0CNr+ZlDAN0hrIdjL2I2KRh04MqMbdz+uq+4q6Mbdj+vAkRlJtP3NpJw2DMtNKaaeuCWR7EbFUrXKLRsG+zUTEeK0/U1RIw3DAjxmr61gN7P3S5qU9JuSrnf39Dp7tSluSSS7UbFUrXLLjq2bloW+RNvfTKh1gzTQtgLtlmKelrRd0vcSGEtHFKYLGts/ppG9IxrbPxZbM2dJJBpRq9yybfOwdm2/TsOD/TJJw4P92rX9Om6cZlmgbQXamrG7+7OSZGbJjCZh1e15KzdEJa1Y7cKSyDC1uvww7ufqlVu2bR4myHtJoKtmgq6x17ohGhXYjSyJRO9odflhrZ+j3BKYtZeXNy9FXO9hdUsxZvaomT0d8efWZt7IzO4ws6KZFefm5lofcRO4IZpvrS4/rPVzlFsyoJke6vWeG+iqmbozdnd/ZxJv5O4PSHpAKp2glMRr1sMN0XxrdvlhpfwSVWpZ+nOUW1LUzM3ORp4b6DF7Qa9j54ZovsUtM4y6Xim/xIV6rddDFzVzs7PR547cJn3saWnyVOlrj4e61Gawm9l7zeyEpN+RVDCzg8kMKxnjG8c1ecOkhgaGaM+bQzu2blL/6r5l1+Lq4VHll0Z+Dl3WzM3OQG+MNqLdVTEPSXooobF0BDdE86tSLmlkVUyt3aHDNPPKjmZudrZ7YzTJjUtd3gQV9KoYoNF6eNwyxuHBfn3/rt/rxNDQimZ6qLfTbz3JjUspbIIKusYONKqZsg1S1EwP9Xb6rSe5cSmFTVDM2AE1V7ZByprpod5qv/Uk6/Mp1PoJdqCMZYy4IMmNSylsgqIUAwDVG5muHktu41IKm6AIdgD5FnUu6n9+VfqtDyZzHmoKZ6tSigGQb3E3N3/6cGnDUhK6fLYqM3YA+RbgRiaCHUC+BXguKsEOIN9aubnZTIfJFFBjB5BvzXZ47IHj9Ah2AL0rqR4szdzcrLWTlGAHgDakNXPugZut1NgB9KYke7A0UzOPu6lqF2Wm5k6wA+hNSc2cozYofeuj8eEcdbNVknyxsZ/vAoIdQG9KapliszP/6p2k1rfyOR3u3lgPNXaghso5qHR8zKB2+q0v1crMf+nN1snB5n++w5ixAzGWnoPqkmZOLWjng8d04MhM2kODlFwPlnZn/hnc4ESwAzGizkFdOLuoPQePpzQiRHplXhdq29/9ePO17Xa7L6bQvbEegh2IEXcOaq3zUdFFR/dJ3/iItPDLV68t/FI68GfR4R638qXdmX8K3RvrocYOxIg7B3XDYMSKCHTfY/dKi6+svH7+7MrNQvXWvLfbfbHL3RvrYcYOxOAc1IyrdXOy+rEUzh1NEzN2IAbnoGZc3JFzlceW6oHdokki2IEaOAc1w266p1Rjry7HXLR65Y3LFM4dTROlGAC9aeQ26dbPS/2Xvnqt/1Jp2xdW1rtrrVzJeAveVjBjB9C7Gr1pGdeaV8p8C95WEOwA8iHqfwL3XZv5Fryt6LlSTGG6oLH9YxrZO6Kx/WMqTBfSHhKAXhXoTdWeCvbCdEGTT05qdn5WLtfs/Kwmn5wk3AG0JoPtAJLQVrCb2R4z+4mZHTWzh8wsphtOMqYOT+nM4pll184sntHU4alOvi2AUGWwHUAS2p2xPyLpWncfkfRfkna2P6R4J+dPNnUdAGrKYDuAJLR189TdH17y7X9Iel97w6lt/cB6zc7PRl4HgKZUn5e6/YGeD/SKJGvsH5b03bgHzewOMyuaWXFubq6lN5jYMqE1fWuWXVvTt0YTWyZaej0AOdXsqUk9xty99hPMHpUUNSW+292/UX7O3ZJGJW33ei8oaXR01IvFYgvDLd1AnTo8pZPzJ7V+YL0mtkxofON4S68FIKfuuzZmJ+oV0see7v54GmRmh9x9tN7z6pZi3P2ddd7oQ5LeLemmRkK9XeMbxwlyAO0JdJljRburYm6W9HFJ73H3l5IZEgB0WKDLHCvarbH/naTXSXrEzJ4ys39IYEwA0FmBLnOsaHdVzK8nNRAA6Jq43jGBrIqhVwyAfMrYqUdJ6qmWAgCA+gh2AAgMwQ4AgSHYAeRLgCcmVePmKYAwVfeCCfjEpGoEO4DwVHrBVAf4qv4gT0yqRrADCM9j90YHePW1ikBaCVRQYwcQnmaDOpBWAhUEO4DwxAV1/6VBtxKoINgBhCeuF8y7/jLIE5OqUWMHEJ56vWACC/JqBDuAMAXcC6YeSjEAEBiCHQACQ7ADQGAIdgAIDMEOAIEh2AEgMAQ7AASGYAeAwBDsAMKSg4M06mHnKYBwxPVhl3K1C5UZO4BwxPVhf+zedMaTEoIdQDji+rAHdpBGPQQ7gHDE9WEP7CCNegh2AOGI68Me2EEa9RDsAMIxclsuDtKop61VMWb2GUm3Sjov6UVJf+TuLyQxMABoSY77sFe0O2Pf4+4j7v5WSd+WlK9/7wBABrUV7O7+qyXfDkjy9oYDAGhX2xuUzOxzkv5Q0mlJ72h7RACAttSdsZvZo2b2dMSfWyXJ3e929yskfUXSnTVe5w4zK5pZcW5uLrn/AgDAMuaeTPXEzN4sqeDu19Z77ujoqBeLxUTeFwDywswOuftovee1VWM3s6uXfPseST9p5/UAAO1rt8a+28w2qbTc8eeS/rT9IQEA2pFYKaapNzWbU+l/BO24TNIvEhhOqPh86uMzqo3Pp75uf0Zvdvd19Z6USrAnwcyKjdSa8orPpz4+o9r4fOrL6mdESwEACAzBDgCB6eVgfyDtAWQcn099fEa18fnUl8nPqGdr7ACAaL08YwcAROipYDez95vZM2Z23sxGqx7baWbPmdlxM9ua1hizxMwmzWzGzJ4q//n9tMeUBWZ2c/n35Dkzuyvt8WSRmf3MzI6Vf29yv03czL5kZi+a2dNLrl1qZo+Y2U/LX9+Q5hiX6qlgl/S0pO2Svrf0opldI+kDkt4i6WZJXzCzvu4PL5Puc/e3lv98J+3BpK38e/F5Se+SdI2k28u/P1jpHeXfm8wt50vBl1XKlqXukvSYu18t6bHy95nQU8Hu7s+6+/GIh26V9HV3f9nd/1vSc5Ku7+7o0COul/Scu0+7+yuSvq7S7w8Qy92/J+mXVZdvlbS3/Pe9krZ1dVA19FSw1zAs6fkl358oX4N0p5kdLf9TMjP/VEwRvyuNcUkPm9khM7sj7cFk1JvcfVaSyl/fmPJ4Lmi7H3vSzOxRSesjHrrb3b8R92MR13Kx3KfW5yXp7yV9RqXP4jOS/lrSh7s3ukzK7e9Kk2509xfM7I2SHjGzn5RnregBmQt2d39nCz92QtIVS76/XFIuzl5t9PMys39U6fjCvMvt70ozKmcXu/uLZvaQSiUsgn25/zWzIXefNbMhlc59zoRQSjHflPQBM3uNmV0l6WpJP0x5TKkr/7JVvFelm8959yNJV5vZVWZ2sUo33b+Z8pgyxcwGzOx1lb9LGhO/O1G+KelD5b9/SFJcRaHrMjdjr8XM3ivpbyWtk1Qws6fcfau7P2Nm+yT9WNI5SR9x98U0x5oRf2Vmb1Wp1PAzSX+S7nDS5+7nzOxOSQcl9Un6krs/k/KwsuZNkh4yM6mUEV91939Nd0jpMrOvSXq7pMvM7ISkT0vaLWmfmf2xpP+R9P70RrgcO08BIDChlGIAAGUEOwAEhmAHgMAQ7AAQGIIdAAJDsANAYAh2AAgMwQ4Agfl/nEMGyDyUZFgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def weighted_covariance(x, y, likelihoods, mu_k):\n", + " return np.sum(likelihoods * (x - mu_k[0]) * (y - mu_k[1])) / likelihoods.sum()\n", + "\n", + "class EM:\n", + " def __init__(self, k):\n", + " self.k = k\n", + " self.z = np.arange(k)\n", + " self.mu = np.random.choice(data.flatten(), k * 2).reshape((k, 2)) # 2 dimensional\n", + " self.cov = np.stack([np.array([[1., 0.], [0., 1.]]) for _ in range(k)])\n", + " self.pi = np.ones(k) / k\n", + " self.log_likelihoods = None\n", + " \n", + " def expectation_step(self, x):\n", + " for z_i in self.z:\n", + " mu_i = self.mu[z_i]\n", + " cov_i = self.cov[z_i]\n", + " self.log_likelihoods[z_i, :] = self.pi[z_i] + stats.multivariate_normal(mu_i, cov_i).logpdf(x)\n", + " # normalize by marginalizing K\n", + " self.log_likelihoods = self.log_likelihoods - np.log(np.exp(self.log_likelihoods).sum(0))\n", + " \n", + " def maximization_step(self, x):\n", + " for z_i in self.z:\n", + " likelihoods = np.exp(self.log_likelihoods[z_i])\n", + " # weighted average\n", + " self.mu[z_i] = ((x * likelihoods[:, None]) / likelihoods.sum()).sum(0)\n", + " \n", + " # weighted variance\n", + " cov = weighted_covariance(x[:, 0], x[:, 1], likelihoods, self.mu[z_i])\n", + " \n", + " cov = np.array([[weighted_covariance(x[:, 0], x[:, 0], likelihoods, self.mu[z_i]), cov],\n", + " [cov, weighted_covariance(x[:, 1], x[:, 1], likelihoods, self.mu[z_i])]])\n", + " \n", + " if np.all(np.linalg.eigvals(cov) > 0) and (1e-6 < np.linalg.det(cov)):\n", + " self.cov[z_i] = cov\n", + " \n", + " # weighted pi\n", + " self.pi[z_i] = likelihoods.sum() / x.shape[0]\n", + " \n", + " def fit(self, x):\n", + " self.log_likelihoods = np.zeros((self.k, x.shape[0]))\n", + " last_log_likelihood = np.inf\n", + " while np.abs(self.log_likelihoods.sum() - last_log_likelihood) > 0.01:\n", + " last_log_likelihood = self.log_likelihoods.sum()\n", + " self.expectation_step(x)\n", + " self.maximization_step(x)\n", + " \n", + " def predict(self, x):\n", + " log_likelihoods = np.zeros((self.k, x.shape[0]))\n", + " for z_i in self.z:\n", + " mu_i = self.mu[z_i]\n", + " cov_i = self.cov[z_i]\n", + " log_likelihoods[z_i, :] = self.pi[z_i] + stats.multivariate_normal(mu_i, cov_i).logpdf(x)\n", + " return log_likelihoods.argmax(0)\n", + " \n", + "m = EM(3)\n", + "m.fit(data)\n", + "y = m.predict(data)\n", + "\n", + "mask = y == 0\n", + "plt.scatter(data[mask][:, 0], data[mask][:, 1])\n", + "mask = y == 1\n", + "plt.scatter(data[mask][:, 0], data[mask][:, 1])\n", + "mask = y == 2\n", + "plt.scatter(data[mask][:, 0], data[mask][:, 1])\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/readme.md b/readme.md index e568e4c..8a9c5d0 100644 --- a/readme.md +++ b/readme.md @@ -9,3 +9,4 @@ * [support vector machine](https://www.ritchievink.com/blog/2017/11/27/implementing-a-support-vector-machine-in-scala/) * [neural network](https://ritchievink.com/blog/2017/07/10/programming-a-neural-network-from-scratch/) * [arima models](https://www.ritchievink.com/blog/2018/09/26/algorithm-breakdown-ar-ma-and-arima-models./) +* [expectation_maximization](/clustering/expectation_maximization.ipynb)