diff --git a/bayesian/normalizing_flows/vae.ipynb b/bayesian/normalizing_flows/vae.ipynb new file mode 100644 index 0000000..af39412 --- /dev/null +++ b/bayesian/normalizing_flows/vae.ipynb @@ -0,0 +1,689 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torchvision import datasets, transforms\n", + "import torch.nn.functional as F\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 128\n", + "dl = torch.utils.data.DataLoader(\n", + " datasets.FashionMNIST('.', train=True, transform=transforms.ToTensor(), download=True),\n", + "batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAO8UlEQVR4nO3db4hV953H8c83OjrqDGiizoo1tpZAEhZqF5GFlCVL2ZLmiWmgmwrZuBB2+sCAJX2QkD6oD8PSPylkaZhuQu3STSlUiQ/CWpFC0icSIzYxcXcTxW1GxVESU2s0zozfPpjjMpq5v994z7nnnJnv+wXDzNzvnHu/Xucz5977vef8zN0FYP67rekGANSDsANBEHYgCMIOBEHYgSAW1nljZsZL/zPo6+tL1oeGhpL1s2fPdqyNj4931VMdli1blqwvXbo0WT9//nyyHnXS5O420+Wlwm5mD0j6iaQFkv7d3Z8tc31RrVy5MlnfsWNHsv7cc891rJ06daqrnuqwcePGUvWRkZFkvc1/6JrQ9cN4M1sg6d8kfV3SvZK2mtm9VTUGoFplnrNvlvS+u59w96uSfiVpSzVtAahambCvlfTBtO9Hi8tuYGbDZnbIzA6VuC0AJZV5zj7TiwCfeUXE3UckjUi8QAc0qcyefVTSumnff07S6XLtAOiVMmF/Q9JdZvYFM1sk6VuS9lbTFoCqdf0w3t0nzOwJSfs0NXp7yd3fqayzeWRwcDBZf+yxx5L1Rx99NFl/5JFHOtY+/vjj5La58dTVq1eT9YGBgWR90aJFHWu5keO+ffuS9cnJyWT9hRdeSNajKTVnd/dXJb1aUS8Aeoi3ywJBEHYgCMIOBEHYgSAIOxAEYQeCqPV49qguXryYrOdm4U899VSyvnPnzo61devWdaxJ+WPKFy5M/4pcuXIlWU/N6Q8ePJjcds+ePcl6bsaPG7FnB4Ig7EAQhB0IgrADQRB2IAjCDgTB6K0FUoeBStJHH32UrKfOLvvkk08mt52YmEjWc6O3S5cuJevvvvtux9rzzz+f3HbDhg3J+tjYWLKOG7FnB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgmLO3QO4Q2FWrViXrJ0+e7Fjbvn17ctv169cn66tXr07Wjx8/nqynZuG5pahzM36zGVcmRgfs2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCObsLZA7pjwnt/Rxyrlz55L106dPJ+u5U1HfeeedHWu5JZfdvVQdNyoVdjM7KemipElJE+6+qYqmAFSvij3737v7+QquB0AP8ZwdCKJs2F3Sb83sTTMbnukHzGzYzA6Z2aGStwWghLIP4+9z99NmtlrSfjP7b3d/bfoPuPuIpBFJMjNeUQEaUmrP7u6ni89jkvZI2lxFUwCq13XYzWyZmQ1e/1rS1yQdraoxANUq8zB+SNKe4pjihZL+093/q5KugrnttvTf3Nw8OTWvXrBgQXLb5cuXJ+u9lDsePffvzh3vjht1fW+5+wlJX6qwFwA9xOgNCIKwA0EQdiAIwg4EQdiBIJhdtMDg4GCyvnjx4mT98uXLHWu50du1a9eS9dz2ubFhL6+7v7+/69uOiD07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTBnL0Fyi5NnKrnZtm5687NunP11PWPj4+Xum4Ocb017NmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAgGlS2Qmyd/8sknyXpqlp677twcPresck6ZZZU//fTTUreNG7FnB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgmLO3QJlzr0vp47qbPC98Tl9fX7Kem7MPDQ1V2c68l/2fNLOXzGzMzI5Ou+x2M9tvZu8Vn1f0tk0AZc3mz/bPJT1w02VPSzrg7ndJOlB8D6DFsmF399ckfXjTxVsk7Sq+3iXpoYr7AlCxbp+zD7n7GUly9zNmtrrTD5rZsKThLm8HQEV6/gKdu49IGpEkM+v+qAgApXT7UutZM1sjScXnsepaAtAL3YZ9r6RtxdfbJL1STTsAeiX7MN7MXpZ0v6SVZjYq6fuSnpX0azN7XNIfJX2zl03OdXfccUeyXvbc7qljxns5J5+N1Jw/d9731LrzkjQwMJCsL1mypOvrno+yYXf3rR1KX624FwA9xNtlgSAIOxAEYQeCIOxAEIQdCIJDXGtw5cqVZD13KGeZ0zHnlL3usks+p+RGkhcuXEjWI47XUtizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQzNlrkJtl5+bJ81Xufunv76+pkxjYswNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEMzZa1B2jp5bdrmXp4tu8rZz1z0xMdH19rl/13zEnh0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmDOXoOlS5cm67njunMz4dQcf3JyMrltbpbdy2PtyyxFPZv64sWLO9YinlM+u2c3s5fMbMzMjk67bKeZnTKzI8XHg71tE0BZs3kY/3NJD8xw+Y/dfWPx8Wq1bQGoWjbs7v6apA9r6AVAD5V5ge4JM3ureJi/otMPmdmwmR0ys0MlbgtASd2G/aeSvihpo6Qzkn7Y6QfdfcTdN7n7pi5vC0AFugq7u59190l3vybpZ5I2V9sWgKp1FXYzWzPt229IOtrpZwG0Q3bObmYvS7pf0kozG5X0fUn3m9lGSS7ppKRv97DHOS83Ty67xnmZNdZzt92ksr318lj7uSgbdnffOsPFL/agFwA9xJ8+IAjCDgRB2IEgCDsQBGEHguAQ1xq0eUnm3Niu7PgrtX3ZpawXLkz/+vb19SXr0bBnB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgmLPXIHeoZe50z2VOJV32MM8yh8/mti/bW+49ACtWdDxbmi5cuFDqtuci9uxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARz9hrkjqvOzZvLzKPLLovcpNz7D3K9L1u2rMp25jz27EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBHP2GixatChZz83Cc+dPb/OsPGViYqLU9levXk3W27wcdROye3YzW2dmvzOzY2b2jpntKC6/3cz2m9l7xefOZwoA0LjZPIyfkPRdd79H0t9K2m5m90p6WtIBd79L0oHiewAtlQ27u59x98PF1xclHZO0VtIWSbuKH9sl6aFeNQmgvFt6zm5mn5f0ZUkHJQ25+xlp6g+Cma3usM2wpOFybQIoa9ZhN7MBSb+R9B13/9NsX/xw9xFJI8V1zM1XkoB5YFajNzPr01TQf+nuu4uLz5rZmqK+RtJYb1oEUIXsnt2mduEvSjrm7j+aVtoraZukZ4vPr/Skw3mgv7+/1Pa50VrqVNNlT9fcpNwptMfHx5P15cuXV9nOnDebh/H3SfonSW+b2ZHismc0FfJfm9njkv4o6Zu9aRFAFbJhd/ffS+r0BP2r1bYDoFfm7mM8ALeEsANBEHYgCMIOBEHYgSA4xLUGuTl7bo6emyen3s3Y5jl7rreyc/Z77rmnY+31119Pbjsftfc3AUClCDsQBGEHgiDsQBCEHQiCsANBEHYgCObsNdiwYUOp7XOnkk7No3Oz6l6fpjo1S8/1lpvD5+bs586dS9ajYc8OBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EwZ6/BlStXkvW+vr5kPTePXriw839jbuWeycnJZD03h89JLauc6lvK9zY4OJisHz9+PFmPhj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgQxm/XZ10n6haS/knRN0oi7/8TMdkr6F0nXDxp+xt1f7VWjc9m+ffuS9bvvvjtZX7FiRbJ++fLlW+7putwcfmJiIlnPvQegjLVr15a67cOHD1fZzpw3mzfVTEj6rrsfNrNBSW+a2f6i9mN3/0Hv2gNQldmsz35G0pni64tmdkxS+k8ugNa5pefsZvZ5SV+WdLC46Akze8vMXjKzGR9rmtmwmR0ys0OlOgVQyqzDbmYDkn4j6Tvu/idJP5X0RUkbNbXn/+FM27n7iLtvcvdNFfQLoEuzCruZ9Wkq6L90992S5O5n3X3S3a9J+pmkzb1rE0BZ2bDb1Mu1L0o65u4/mnb5mmk/9g1JR6tvD0BVLHeqYDP7iqTXJb2tqdGbJD0jaaumHsK7pJOSvl28mJe6rnLnJZ6nlixZkqw//PDDyfqqVas61gYGBpLb5g5hzY3eclLXn7vuDz74IFnfvXt3sn7p0qVkfb5y9xnnqbN5Nf73kmbamJk6MIfwDjogCMIOBEHYgSAIOxAEYQeCIOxAENk5e6U3FnTOnjuMtJf/BytXrkzW169fn6ynZvhSvvcTJ050rI2Ojia3LXPorpS+3+v8va9bpzk7e3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCKLuOfs5Sf837aKVks7X1sCtaWtvbe1LorduVdnbenef8c0RtYb9Mzdudqit56Zra29t7Uuit27V1RsP44EgCDsQRNNhH2n49lPa2ltb+5LorVu19Nboc3YA9Wl6zw6gJoQdCKKRsJvZA2b2P2b2vpk93UQPnZjZSTN728yONL0+XbGG3piZHZ122e1mtt/M3is+p9dzrre3nWZ2qrjvjpjZgw31ts7Mfmdmx8zsHTPbUVze6H2X6KuW+6325+xmtkDS/0r6B0mjkt6QtNXd3621kQ7M7KSkTe7e+BswzOzvJP1Z0i/c/a+Ly/5V0ofu/mzxh3KFuz/Vkt52Svpz08t4F6sVrZm+zLikhyT9sxq87xJ9/aNquN+a2LNvlvS+u59w96uSfiVpSwN9tJ67vybpw5su3iJpV/H1Lk39stSuQ2+t4O5n3P1w8fVFSdeXGW/0vkv0VYsmwr5W0vR1fUbVrvXeXdJvzexNMxtuupkZDF1fZqv4vLrhfm6WXca7TjctM96a+66b5c/LaiLsM50fq03zv/vc/W8kfV3S9uLhKmZnVst412WGZcZbodvlz8tqIuyjktZN+/5zkk430MeM3P108XlM0h61bynqs9dX0C0+jzXcz/9r0zLeMy0zrhbcd00uf95E2N+QdJeZfcHMFkn6lqS9DfTxGWa2rHjhRGa2TNLX1L6lqPdK2lZ8vU3SKw32coO2LOPdaZlxNXzfNb78ubvX/iHpQU29In9c0vea6KFDXxsk/aH4eKfp3iS9rKmHdeOaekT0uKQ7JB2Q9F7x+fYW9fYfmlra+y1NBWtNQ719RVNPDd+SdKT4eLDp+y7RVy33G2+XBYLgHXRAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EMRfANHBpxawvfpYAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(1 - next(iter(dl))[0][2].view(28, 28).numpy(), cmap='Greys')" + ] + }, + { + "cell_type": "code", + "execution_count": 169, + "metadata": {}, + "outputs": [], + "source": [ + "def middle_size(size_in, size_out):\n", + " \"\"\"\n", + " Make a funnel with regard to the number of weights.\n", + " \"\"\"\n", + " return (size_in - size_out) // 2 + size_out\n", + "\n", + "class VAE(nn.Module):\n", + " def __init__(self, latent_size):\n", + " super().__init__()\n", + " \n", + " self.latent_size = latent_size \n", + " size = middle_size(784, self.latent_size)\n", + " \n", + " # Encode into gaussian paramaters. First part of vector is mu and second part log variance\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(784, size),\n", + " nn.ReLU(),\n", + " nn.Linear(size, size),\n", + " nn.ReLU(),\n", + " nn.Linear(size, latent_size * 2),\n", + " )\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(latent_size, size),\n", + " nn.ReLU(),\n", + " nn.Linear(size, size),\n", + " nn.ReLU(),\n", + " nn.Linear(size, 784),\n", + " nn.Sigmoid()\n", + " )\n", + " \n", + " def reparameterize(self, mu, log_var):\n", + " std = torch.exp(0.5 * log_var)\n", + " eps = torch.randn_like(std) # unit gaussian\n", + " z = mu + eps * std \n", + " return z\n", + "\n", + " def forward(self, x):\n", + " z_theta = self.encoder(x)\n", + " mu = z_theta[:, :self.latent_size]\n", + " log_var = z_theta[:, self.latent_size:]\n", + " z = self.reparameterize(mu, log_var)\n", + " return self.decoder(z), mu, log_var" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CNN" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 171, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXq0lEQVR4nO2da4zV5bXGnwUDyP0uQ7mDVG5VwBFREUQqgk2DpNFgE8tJGukH29qkH05TP9RPTXNyaNOkp03x1EpNL8FQjqZSj4SgSC/IgIDcQS4DggOCwIAIc1nnw2x6qJ33WdO57D3nvM8vITOzn1n7/+7/5pn/3nu9ay1zdwgh/v/TqdQLEEIUB5ldiEyQ2YXIBJldiEyQ2YXIhLKiHqyszLt06cJ0Gt+jR4+k9sknn9DYTp3437Xa2lqqd+/ePaldvXqVxrLHDAD19fVUj85LazIq0ePu2rUr1Tt37kx19tiuXbtGY6PHHR27oaGhxcfu1q0b1c2sxccG+HNWV1dHY9n/9bq6OjQ0NDS5uFaZ3cwWAPgxgM4A/tPdf8B+v0uXLhg9enRSLy8vp8e7/fbbk9r+/ftpLPtDAQDV1dVUnzBhQlI7fPgwjY0e1+XLl6k+cOBAqjPDRn9IPvjgA6qPHDmS6n369KH6xYsXk9rJkydpbPS4o2Oz83r8+HEaO3bsWKpHf8Cj55T9MTh79iyN3b17d1I7c+ZMUmvxy3gz6wzgPwAsBDAJwONmNqml9yeEaF9a8559BoBD7n7Y3a8B+B2ARW2zLCFEW9Masw8DcONroROF2/4OM1tmZpVmVhm9FxFCtB+tMXtTHwL8w6cO7r7C3SvcvSL6wEUI0X60xuwnAIy44efhAPgnLkKIktEas28BMN7MxphZVwBLALzSNssSQrQ1LX5d7e51ZvZ1AP+NxtTb8+6ezgmgMd3A8ptXrlyhxxwxYkRSi/KagwYNovqJEyeoznK6Dz30EI3t3bs31Q8cOEB1luMHeOrt5ptvprFRrvro0aNUj87rLbfcktRYGhaI90ZEax8yZEhSe+edd2hslEcfN24c1c+fP9/i+48e94ULF1p03Fa9iXb3tQDWtuY+hBDFQdtlhcgEmV2ITJDZhcgEmV2ITJDZhcgEmV2ITLBidpft0aOH33rrrUl9zpw5NH7WrFlJLcpN9u3bl+rbt2+nOiunZLlkADh37hzVhw4dSvWDBw9SvVevXkktKu09duwY1efNm0f1aH8CywlHPQiiEtdI37lzZ1Lr2bMnje3Xrx/VP/74Y6qz0l6A772I7vvVV19Naq+99hrOnj3bZBJfV3YhMkFmFyITZHYhMkFmFyITZHYhMkFmFyITito6pqGhATU1NUk9KpdkHV6j1sCDBw+m+r59+6i+YMGCpPbHP/6Rxt51112tOvaAAQOozs5b//79aWyUYorKSG+66aYW33/U+TYq7b106RLV586dm9TeeOMNGstSxECcTo1adLMOsqw0F+BpPXZOdWUXIhNkdiEyQWYXIhNkdiEyQWYXIhNkdiEyQWYXIhOKnmdnZY2nT5+m8awcNyphXb16NdW/+MUvUp3lZRcuXNiqY0+dOpXqUb6ZTRTdsGEDjZ08eTLVjxw5QvWoxJVNM41y2dGE2agV9datW5NatO6ofLY1U1oB4MMPP2zxsdn+AuXZhRAyuxC5ILMLkQkyuxCZILMLkQkyuxCZILMLkQkdKs/erVs3Gr9///6kNn36dBr7pS99ieosVx0RtUQeM2YM1YcNG0b1qM11eXl5UotGNketpKMeA/fddx/V2cjnTZs20dioZvyb3/wm1dlzescdd7Q4Foj7J0T/J+bPn5/UNm/eTGOZT1hL9VaZ3cyOAqgBUA+gzt0rWnN/Qoj2oy2u7HPdPb0dSAjRIdB7diEyobVmdwCvm9lWM1vW1C+Y2TIzqzSzymKOmhJC/D2tfRl/r7ufNLObAawzs33uvvHGX3D3FQBWAEBZWZncLkSJaNWV3d1PFr6eBrAGwIy2WJQQou1psdnNrKeZ9b7+PYD5AHa11cKEEG1La17GDwGwxsyu389v3P21KIi9b4/G6BaO1SRRffHhw4ep/rnPfY7qbKzyihUraOxtt91G9WjtUd/5v/zlLy0+9vjx46n+5z//mepRz/v169cnteXLl9PYjRs3Un3VqlVUZ/XuLM8NAFVVVVSP9i8cOHCA6sePH09qUY4/Gk+eosVmd/fDAG5vabwQorgo9SZEJsjsQmSCzC5EJsjsQmSCzC5EJhS1xNXMUFaWPmSvXr1o/AMPPJDU2BhbIC4zjdoWs3bP0X3v3r2b6lGZaQRLA0Uttrds2dLi+waAXbv41orFixcntTfffJPGRq2m+/TpQ3XWwjt6ziI9SjlGba4PHTqU1C5cuEBjr1y5ktRYC2td2YXIBJldiEyQ2YXIBJldiEyQ2YXIBJldiEyQ2YXIhKLm2bt164axY8cm9SjXzVoLV1ZW0tioLLBz585UZ/nmqCSRjQ4GgBEjRlB91KhRVGePLWpD/frrr1P9zjvvpPq0adOovmbNmqTG9k0AcQ5/woQJVJ8zZ05Si/LkK1eupPqSJUuoHj3nrJw7OqdvvfVWUmP7WHRlFyITZHYhMkFmFyITZHYhMkFmFyITZHYhMkFmFyITippnB3hOOKpnZ3XfQ4YMobF1dXVUj47NxujW1NTQ2EcffZTqUb17dXU11Y8cOZLUoscdtZI+ePAg1aORzkuXLk1q9fX1NDYaXTx48GCqsz0C0d4Gdk6B+DmL6tnZ8xLtAWCxrFW7ruxCZILMLkQmyOxCZILMLkQmyOxCZILMLkQmyOxCZEJR8+x1dXU0Zzxjxgwaf/78+RZpQNz//Pbb+UBapn//+9+nsVHONqrrvnr1KtVZH4CRI0fS2Llz51L9xRdfpDobyQwADz74YFKbPHkyjV24cCHVN2zYQHXW+3348OE0try8nOozZ86kOhujDfA8fzRm+6WXXkpqreobb2bPm9lpM9t1w20DzGydmR0sfO0f3Y8QorQ052X8CwAWfOq27wBY7+7jAawv/CyE6MCEZnf3jQA+3Q9qEYDrfXtWAnikjdclhGhjWvqefYi7nwIAdz9lZsk3xGa2DMAygPfHEkK0L+3+aby7r3D3CneviJo6CiHaj5aavdrMhgJA4evptluSEKI9aKnZXwFwvXZxKYCX22Y5Qoj2wlj9KwCY2W8B3A9gEIBqAN8D8F8AVgEYCaAKwKPunm7qXqCsrMzZTO1Vq1bReJYbjfLsUe00y08CoP3uP/74YxpbVVVF9aiWPup5z+Z5V1RU0Fgzo3pUq//CCy9QndWkP/PMMzQ2ysPv2bOH6uvWrUtqX/nKV2js22+/TfWoT0D0lpX1Xzh16hSN/elPf5rUtm3bhpqamiaf1PATM3d/PCHNi2KFEB0HbZcVIhNkdiEyQWYXIhNkdiEyQWYXIhOKun+1U6dONM20du1aGv/YY48ltU8++YTGRimmqC0xSyFNnTqVxm7cuJHqU6ZMoXrUrpmlBaMtylGaJ0oLLl68mOqzZ89OalH57XPPPdfi+wZ4eoylK4F4VPXp03wf2eXLl6nOyp779+dFpP369UtqLOWnK7sQmSCzC5EJMrsQmSCzC5EJMrsQmSCzC5EJMrsQmVDUPLuZ0XLNYcOG0XiWbx44cCCNPXPmDNWjEtdx48Ylte7du9PYKGcblRlHpb9PP/10UovG/0aji6NSzagl87lz6crnqOw44uWXeRuFefPShZkfffQRjY3OefScRnsjJk6cmNSikuja2tqkppHNQgiZXYhckNmFyASZXYhMkNmFyASZXYhMkNmFyISi5tkbGhponW+U62Y1yFF98Wc/+1mqd+nSheo7duxIatFI5ah2+u6776Y6y6tGetTmOsoH//KXv6R63759qc5GZffo0YPGRnsAonp4Nh784YcfprER27Zto/qBAweofs899yS11rSpZn0bdGUXIhNkdiEyQWYXIhNkdiEyQWYXIhNkdiEyQWYXIhOKmmcvKyuj/dmj2mmWfxwwYACNffPNN6kexbMRu1GOP+ppf+3aNapHfelZ7/eof3m0t4HlyYG4r/yVK1eS2jvvvENj77//fqpHPe/Z3oif//znNHbRokVUj9YW7a1g/2eiOn923yw2vLKb2fNmdtrMdt1w27Nm9r6ZbS/8a90OBSFEu9Ocl/EvAFjQxO0/cvephX98lIsQouSEZnf3jQDSvYWEEP8naM0HdF83s52Fl/nJ4VRmtszMKs2sMtrzK4RoP1pq9p8BGAdgKoBTAJanftHdV7h7hbtXREMGhRDtR4vM7u7V7l7v7g0AngMwo22XJYRoa1pkdjMbesOPiwHsSv2uEKJjEL6uNrPfArgfwCAzOwHgewDuN7OpABzAUQBfa87B3J3myqP6ZpavjnL0ly5dCtfGYDPWn3zySRrbs2dPqkc97aOa9JtuuimpTZs2jcb+6U9/onpUr3706FGqs/M+d+5cGhvtX4j2ALA8POvbDgB/+MMfqB7FR7MEKisrk9r06dNpLPv/wPZNhGZ398ebuPkXUZwQomOh7bJCZILMLkQmyOxCZILMLkQmyOxCZEJRt7Rdu3YNx44dS+pRCmr06NFJLdqKe9ddd1E9Sm8dOnQoqR0/fpzGRhw+fJjqUXrskUceSWq9e/emsVG75g8++IDqrHUxAPp8R+mpTZs2UX3+/PlULy8vT2rR42bPN8DTnUCckly6dGlS++tf/0pjWSqXjUTXlV2ITJDZhcgEmV2ITJDZhcgEmV2ITJDZhcgEmV2ITCh66xiWB4xyk3fccUdS2759O40dNmwY1c+ePUv1UaNGJbXz58/T2HfffZfqCxcupHq0dpbL3rdvH4295ZZbqB7l0aPSYpZLj9pcR3sjVq9eTXXWtjwa98z2LgDAuXO8LWNUGrxmzZqkFpVE9++f7AJHnw9d2YXIBJldiEyQ2YXIBJldiEyQ2YXIBJldiEyQ2YXIhKLm2bt27UrzmyyPDvCxygsWNDV78n/56KOPqD5u3Diqz549O6lt3ryZxkYtsqOcbcSsWbOS2q233kpjo1r6qqoqqn/mM5+h+pgxY5LayZMnaezkyZOpPnDgQKq/9NJLSS3aVzFo0CCqb9u2jerRGG52/9FzUlNTk9RYK2ld2YXIBJldiEyQ2YXIBJldiEyQ2YXIBJldiEyQ2YXIhKLm2d0dtbW1ST3Ku37hC19IalE9+9tvv031xYsXU/3ll19Oavfeey+NjXL4Ua6anTMA2Lp1a1KLRg9HueyZM2dS/Sc/+QnV77777qQW1dKzkcsAcOHCBaqz/Qds/DfARyoDwJw5c6i+YcMGqk+aNCmpsd4JAJ8jwEaPh1d2MxthZhvMbK+Z7Tazpwu3DzCzdWZ2sPA1XVEvhCg5zXkZXwfg2+4+EcBMAE+Z2SQA3wGw3t3HA1hf+FkI0UEJze7up9x9W+H7GgB7AQwDsAjAysKvrQTA+/gIIUrKP/We3cxGA5gGYDOAIe5+Cmj8g2BmNydilgFYBsT9yoQQ7UezP403s14AVgP4lrtfbG6cu69w9wp3r5DZhSgdzTK7mXVBo9F/7e6/L9xcbWZDC/pQAKfbZ4lCiLYgfBlvjb2EfwFgr7v/8AbpFQBLAfyg8DWdmyrQ0NAQjkZmsBTTbbfdRmM///nPUz1q97xjx46kFo3vZWWezWH58uVUf+KJJ5Iaa6cMABMmTKB61Cb7y1/+MtUvXbqU1KKxyGvXrqX6gw8+SHWWyn3ooYdobFRmunv3bqpH552NGI/KsVk7dtb6uznv2e8F8ASAd83sejL7u2g0+Soz+yqAKgCPNuO+hBAlIjS7u28CkPpzMa9tlyOEaC+0XVaITJDZhcgEmV2ITJDZhcgEmV2ITChqiauZ0S2zN9/c5I7bv8Fyn8OHD6exUUnjnXfeSfUzZ84ktTfeeIPG3nfffVSPdhY+9dRTVGejj0ePHk1jI9jjBuI8fJcuXZJaVNobPe4o1z137tykFrWSnjhxItWj8xLdf3V1dVKLWqqz1uQsB68ruxCZILMLkQkyuxCZILMLkQkyuxCZILMLkQkyuxCZUNQ8e+fOndG3b9+k3qdPHxrPxiZHNeURUd60V69eSe2ee+6hsWVl/DS/9957VB8yZAjV16xZk9Si/QPdu3enejTyeefOnVRntfx79+6lsVeuXKF6eXk51V999dWktmTJEhob9TeI9kZEI5/ZnpJdu3a16tgpdGUXIhNkdiEyQWYXIhNkdiEyQWYXIhNkdiEyQWYXIhOKmmevr6+nY3ZZb3YAGD9+fFKLctlslC0AdOvWjepsf8CUKVNo7J49e6jOcvhAfF6+8Y1vJLVr167R2Pr6eqpHo66j2mt2/1EPgqqqKqpH/fjZ/V+8yIcasb7uQDxu+v3336f6gAEDklq/fv1oLKtnZzl4XdmFyASZXYhMkNmFyASZXYhMkNmFyASZXYhMkNmFyITmzGcfAeBXAMoBNABY4e4/NrNnATwJ4Hoh+HfdnQ7Udnc0NDQkdZY/BIATJ04ktSjPfvXqVapH891Zf/SoLnv//v1UZ/3Ngbi/+tGjR5NaTU0NjR04cCDVo5ryffv2Uf3IkSNJLcqT9+zZk+pbtmyhOtvTEdX5R89p7969qV5bW0t1tocg2vPB9gCw/STN2VRTB+Db7r7NzHoD2Gpm6wraj9z935txH0KIEtOc+eynAJwqfF9jZnsBDGvvhQkh2pZ/6j27mY0GMA3A5sJNXzeznWb2vJn1T8QsM7NKM6tkL+GFEO1Ls81uZr0ArAbwLXe/COBnAMYBmIrGK//ypuLcfYW7V7h7BZtDJYRoX5rlPjPrgkaj/9rdfw8A7l7t7vXu3gDgOQAz2m+ZQojWEprdzAzALwDsdfcf3nD70Bt+bTEA3hJTCFFSLCr9NLNZAN4C8C4aU28A8F0Aj6PxJbwDOArga4UP89h90YNFpaIPPPBAUotG5EbpjKj1L0sTRam1qBzywIEDVB85ciTVjx07ltSix3XqFH3KwhbdbFw0wB/7oUOHaGzUKjoaF83SitG6o7Lj6LxFz9nBgweT2oQJE2js6tWrk9quXbtw+fJla0przqfxmwA0FUxz6kKIjoU+MRMiE2R2ITJBZhciE2R2ITJBZhciE2R2ITIhzLO3JZ06dXJWijpp0iQaz8bcslbPAC8DBeKcLsvLRseO8vDRCN4o58ty6dFI5mj/QVTiGpV6shLY6LxF26ujElhW3huVREfPSVSOHe0BYM9pdM63bt2a1D788EPU1tY2mWfXlV2ITJDZhcgEmV2ITJDZhcgEmV2ITJDZhcgEmV2ITChqnt3MzgC4sfh6EIAPi7aAf46OuraOui5Aa2spbbm2Ue4+uCmhqGb/h4ObVbp7RckWQOioa+uo6wK0tpZSrLXpZbwQmSCzC5EJpTb7ihIfn9FR19ZR1wVobS2lKGsr6Xt2IUTxKPWVXQhRJGR2ITKhJGY3swVmtt/MDpnZd0qxhhRmdtTM3jWz7WZWWeK1PG9mp81s1w23DTCzdWZ2sPC1yRl7JVrbs2b2fuHcbTezh0u0thFmtsHM9prZbjN7unB7Sc8dWVdRzlvR37ObWWcABwA8COAEgC0AHnf3PUVdSAIzOwqgwt1LvgHDzGYDuATgV+4+pXDbvwE45+4/KPyh7O/u/9pB1vYsgEulHuNdmFY09MYx4wAeAfAvKOG5I+t6DEU4b6W4ss8AcMjdD7v7NQC/A7CoBOvo8Lj7RgDnPnXzIgArC9+vRON/lqKTWFuHwN1Pufu2wvc1AK6PGS/puSPrKgqlMPswAMdv+PkEOta8dwfwupltNbNlpV5MEwy5Pmar8DXdq6s0hGO8i8mnxox3mHPXkvHnraUUZm+qP1ZHyv/d6+7TASwE8FTh5apoHs0a410smhgz3iFo6fjz1lIKs58AMOKGn4cDOFmCdTSJu58sfD0NYA063ijq6usTdAtfT5d4PX+jI43xbmrMODrAuSvl+PNSmH0LgPFmNsbMugJYAuCVEqzjHzCznoUPTmBmPQHMR8cbRf0KgKWF75cCeLmEa/k7OsoY79SYcZT43JV8/Lm7F/0fgIfR+In8ewCeKcUaEusaC2BH4d/uUq8NwG/R+LKuFo2viL4KYCCA9QAOFr4O6EBrexGNo713otFYQ0u0tllofGu4E8D2wr+HS33uyLqKct60XVaITNAOOiEyQWYXIhNkdiEyQWYXIhNkdiEyQWYXIhNkdiEy4X8AAKM8rRKqgMUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + " \n", + "class UnFlatten(nn.Module):\n", + " def forward(self, x):\n", + " # Note -1 dimension is inferred.\n", + " return x.view(x.shape[0], -1, 1, 1)\n", + " \n", + " \n", + "class CNNVAE(VAE):\n", + " # TODO: find good architecture. \n", + " def __init__(self, latent_size=400):\n", + " super().__init__(latent_size)\n", + " \n", + " self.encoder = nn.Sequential(\n", + " nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),\n", + " nn.ReLU(),\n", + " nn.Flatten()\n", + " )\n", + " \n", + " self.decoder = nn.Sequential(\n", + " UnFlatten(),\n", + " nn.ConvTranspose2d(self.latent_size, 32, kernel_size=5, stride=2),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + "m = CNNVAE()\n", + "plt.imshow(m(next(iter(dl))[0])[0][0].view(28, 28).data.numpy(), cmap='Greys')" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "metadata": {}, + "outputs": [], + "source": [ + "def loss_function(reconstruction_x, x, mu, log_var):\n", + " BCE = F.binary_cross_entropy(reconstruction_x, x.view(-1, 784), reduction='sum')\n", + "\n", + " # see Appendix B from VAE paper:\n", + " # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014\n", + " # https://arxiv.org/abs/1312.6114\n", + " # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n", + " KLD = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())\n", + "\n", + " return BCE + KLD" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Epoch: 1 [0/60000 (0%)]\tLoss: 545.216492\n", + "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 304.467712\n", + "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 282.475281\n", + "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 270.869110\n", + "====> Epoch: 1 Average loss: 302.0279\n", + "Train Epoch: 2 [0/60000 (0%)]\tLoss: 268.287903\n", + "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 273.701294\n", + "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 257.615631\n", + "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 255.759094\n", + "====> Epoch: 2 Average loss: 262.8937\n", + "Train Epoch: 3 [0/60000 (0%)]\tLoss: 252.945618\n", + "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 264.578156\n", + "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 250.188477\n", + "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 252.424652\n", + "====> Epoch: 3 Average loss: 253.8822\n", + "Train Epoch: 4 [0/60000 (0%)]\tLoss: 247.508987\n", + "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 260.838867\n", + "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 247.394196\n", + "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 248.601410\n", + "====> Epoch: 4 Average loss: 250.5403\n", + "Train Epoch: 5 [0/60000 (0%)]\tLoss: 243.754486\n", + "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 258.409210\n", + "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 245.839935\n", + "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 248.752655\n", + "====> Epoch: 5 Average loss: 248.1759\n", + "Train Epoch: 6 [0/60000 (0%)]\tLoss: 242.354782\n", + "Train Epoch: 6 [19200/60000 (32%)]\tLoss: 256.875854\n", + "Train Epoch: 6 [38400/60000 (64%)]\tLoss: 245.116455\n", + "Train Epoch: 6 [57600/60000 (96%)]\tLoss: 247.758835\n", + "====> Epoch: 6 Average loss: 246.9103\n", + "Train Epoch: 7 [0/60000 (0%)]\tLoss: 241.033920\n", + "Train Epoch: 7 [19200/60000 (32%)]\tLoss: 256.162842\n", + "Train Epoch: 7 [38400/60000 (64%)]\tLoss: 242.574280\n", + "Train Epoch: 7 [57600/60000 (96%)]\tLoss: 245.996399\n", + "====> Epoch: 7 Average loss: 245.8804\n", + "Train Epoch: 8 [0/60000 (0%)]\tLoss: 239.238083\n", + "Train Epoch: 8 [19200/60000 (32%)]\tLoss: 255.056564\n", + "Train Epoch: 8 [38400/60000 (64%)]\tLoss: 242.117462\n", + "Train Epoch: 8 [57600/60000 (96%)]\tLoss: 245.257553\n", + "====> Epoch: 8 Average loss: 244.7513\n", + "Train Epoch: 9 [0/60000 (0%)]\tLoss: 238.755569\n", + "Train Epoch: 9 [19200/60000 (32%)]\tLoss: 253.945724\n", + "Train Epoch: 9 [38400/60000 (64%)]\tLoss: 241.528793\n", + "Train Epoch: 9 [57600/60000 (96%)]\tLoss: 243.930283\n", + "====> Epoch: 9 Average loss: 243.8302\n", + "Train Epoch: 10 [0/60000 (0%)]\tLoss: 238.498886\n", + "Train Epoch: 10 [19200/60000 (32%)]\tLoss: 253.728821\n", + "Train Epoch: 10 [38400/60000 (64%)]\tLoss: 241.584671\n", + "Train Epoch: 10 [57600/60000 (96%)]\tLoss: 244.546021\n", + "====> Epoch: 10 Average loss: 243.3323\n" + ] + } + ], + "source": [ + "latent_size = 250\n", + "m = VAE(latent_size)\n", + "m.cuda()\n", + "\n", + "def train(epoch, m, device, optimizer):\n", + " m.train()\n", + " train_loss = 0\n", + " for batch_idx, (x, _) in enumerate(dl):\n", + " x = x.to(device).view(-1, 784)\n", + " \n", + " optim.zero_grad()\n", + " recon_batch, mu, logvar = m(x)\n", + " loss = loss_function(recon_batch, x, mu, logvar)\n", + " loss.backward()\n", + " \n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " if batch_idx % 150 == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(x), len(dl.dataset),\n", + " 100. * batch_idx / len(dl),\n", + " loss.item() / len(x)))\n", + "\n", + " print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(dl.dataset)))\n", + "\n", + "\n", + "optim = torch.optim.Adam(m.parameters(), 1e-3)\n", + "\n", + "for epoch in range(1, 11):\n", + " train(epoch, m, 'cuda', optim)" + ] + }, + { + "cell_type": "code", + "execution_count": 180, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " # take 5 samples\n", + " sample = torch.randn(5, latent_size).to('cuda')\n", + " sample = m.decoder(sample).cpu().view(-1, 28, 28)" + ] + }, + { + "cell_type": "code", + "execution_count": 181, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAARZElEQVR4nO3dX4yc1XnH8d9jm782xja72GsbuiE2yFCpYEaoElVEFTUCbiAXqcJFRCVU5wKkRMpFEb0Il6hqEuWiiuQUFKdKiSIlCC5QG4QiWbmJWMDFBrs1WAs2XnbXxgaD+Wf76cW+VIvZ93mGeWfmnXK+H8na3Tl79j3z7vw8O/O85xxzdwH48lvW9gAADAdhBwpB2IFCEHagEIQdKMSKYR5sbGzMJycnh3lIoCjT09M6duyYLdXWKOxmdrukn0paLulf3f2R6PsnJyc1NTXV5JAAAp1Op7at5z/jzWy5pH+RdIek6yXdY2bX9/rzAAxWk9fst0h61d0PufvHkn4t6a7+DAtAvzUJ+yZJhxd9faS67TPMbIeZTZnZ1Pz8fIPDAWiiSdiXehPgc9feuvtOd++4e2d8fLzB4QA00STsRyRdtejrzZKONhsOgEFpEvbnJG01s6+Y2YWSvi3pqf4MC0C/9Vx6c/czZvaApP/UQuntMXd/uW8jA9BXjers7v60pKf7NBYAA8TlskAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhhrqUNEZPtrFn1m625KrFXWnSF18cz+xAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSCOvuXwLlz52rbzp49G/bN2pscW5KWL19e27ZiRfzwW7as2XMRdfzP4pkdKARhBwpB2IFCEHagEIQdKARhBwpB2IFCUGcfAdmc8ayW/fHHH9e2ZXX0rJadtWdjP336dG3bBRdcEPa98MILw/asTh8psQbfKOxmNi3plKSzks64e6cfgwLQf/14Zv9rdz/Wh58DYIB4zQ4UomnYXdLvzex5M9ux1DeY2Q4zmzKzqfn5+YaHA9CrpmG/1d23S7pD0v1m9rXzv8Hdd7p7x9074+PjDQ8HoFeNwu7uR6uPc5KekHRLPwYFoP96DruZrTSzyz79XNI3JO3r18AA9FeTd+PXS3qiqleukPTv7v4ffRnVl0zTOvqHH34Ytr/99tu1batXrw77ZrXurB6d3bdo7MePHw/7rl27NmxfuXJl2B7dtxLnyvccdnc/JOkv+jgWAANE6Q0oBGEHCkHYgUIQdqAQhB0oRDFTXLMS0SBl00w/+eSTsP2dd97puX3dunVh32ip525k5zWafvvee++FfZssUy3FY8umz45yaa3XsfHMDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIb40dfY26+jZsbM6e7TcsiS9++67YXs01TOryWa17EzWP7rv77//ftg3m9qb3bdoimzTJbSzYzep0w/qZ/PMDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIUaqzp7Vq5vU0pvWkyNnzpwJ2z/66KOwPVoKWpJOnDgRtkdz1rOxZeclm2v/wQcfhO3RNQLZ/Y7mwkv5edm0aVNt28aNG8O+2Vz5pnX4qH/Traxrj9lTLwD/7xB2oBCEHSgEYQcKQdiBQhB2oBCEHSjE0OvsUa08q/lGfbN6clYvzvpHddFsvvnrr78ets/NzYXtl156adgeybY1ztZuz+aUnzp1KmyfnZ2tbTt8+HDYN1svP1sn4Morr6xt27JlS9g3O+fZ9QUbNmwI29esWVPblm1VndXh66TP7Gb2mJnNmdm+RbetM7NnzOxg9TEeHYDWdfNn/C8k3X7ebQ9Ketbdt0p6tvoawAhLw+7uuyWdf13jXZJ2VZ/vknR3n8cFoM96fYNuvbvPSFL1sfbFkZntMLMpM5uan5/v8XAAmhr4u/HuvtPdO+7eGR8fH/ThANToNeyzZjYhSdXH+O1kAK3rNexPSbq3+vxeSU/2ZzgABiWts5vZ45JukzRmZkck/VDSI5J+Y2b3SXpD0re6PWBUS89q4VF7tgb5wYMHw/ZsrvzFF19c27Z3796w7549e8L21atXh+3Zy59rr722ti17n6TpmvXZvO6ZmZmej521Z9cvRHPKDxw4EPbN5tJntfCxsbGwffv27bVtk5OTYd9LLrmkti16HKdhd/d7apq+nvUFMDq4XBYoBGEHCkHYgUIQdqAQhB0oxFCnuLp7o9JbNJ3y5MmTYd+s/JUta3zRRRfVtmWlt2zqbrY0cDYNNStBRZouc52VmKIpstkU1uzxkC33HJVjs6m72RTWbBnrqOQoxed9YmIi7Nvrsug8swOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UIihLyUdLf+bTVON+mZLGr/xxhthe1anj6ZyZvXg7Gdv27YtbM9qvq+99lptW7bscFarjq4vkPKtiaOpotk00mz6bPZ4eeutt2rbVq1aFfZtumXzihVxtKKxZ/crmsYatfHMDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIYY+nz2qSWfLFke1y2zedbT8riQdOXIkbI+WDs5qqtmSyNnc56zOHtWrsy2bsyWRs/uWbXUdbX2c/c6ynx0t7y3F8+Wz+eibN28O27NaeDa2qP+hQ4fCvjfccEPYXodndqAQhB0oBGEHCkHYgUIQdqAQhB0oBGEHCjH0Ons03zarJ0dzp7OabTS3WcrXZo+OnW2LnM2dzurs2fUHGzdurG3Lri+I6uBSPl/9+PHjYXtUz85q+NnvJFtvP1r/YHp6OuzbZPtwKR9b9HvJ1l7Ithevkz6zm9ljZjZnZvsW3fawmb1pZnuqf3f2dHQAQ9PNn/G/kHT7Erf/xN1vrP493d9hAei3NOzuvltS/DcygJHX5A26B8zsperP/NoLrM1sh5lNmdlU9voOwOD0GvafSfqqpBslzUj6Ud03uvtOd++4e+eKK67o8XAAmuop7O4+6+5n3f2cpJ9LuqW/wwLQbz2F3cwW7yn7TUn76r4XwGhI6+xm9rik2ySNmdkRST+UdJuZ3SjJJU1L+m43B3P3cF/srJ58+PDh2rZo7XRJOnjwYNierY8e7YF+7NixsO+6devC9qzenO2hHu3XnV1fcPTo0bA9q8Nn1xBE9ehsj/TVq1eH7QcOHAjbZ2dna9uyl5TZ9QVZrTt7LEf9s2NH8/yjn5uG3d3vWeLmR7N+AEYLl8sChSDsQCEIO1AIwg4UgrADhRjqFNdz586Fyx5nU0Uj2XTIyy+/PGzPli2Oxr1mzZqwbzQFVcq3B86Woo5Kf1mJKSv7ZUsiZ2OP2rPpt9mU52zb5Kh0d9lll/XcV8pLtdkS3tHvJXs8RaVWtmwGQNiBUhB2oBCEHSgEYQcKQdiBQhB2oBBDr7NH9fAXX3wx7L9+/fratmzKYdMlk6PaZzZVM1tWeGxsLGyP6qpSfo1BJJuKmR07q8OfOnXqC4/pU9n03Gzb5GhqcNY3u+Yjq8Nnj8fo8TQ3Nxf2jZbIjvDMDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIYZaZ1++fHk4j/jmm28O+7/55pu1bVmdPKv3ZrXLaIvfbN51Nj85mxudbUcd1dmzOnk0T1+SJiYmwvZsXnh03rNtj7P7nS2xHZ2X7NqErD1bHjxbPyE679njKbquI/p988wOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhhl5nj+qT11xzTdg/mheerUGe1WRPnjwZtke1z6zv1VdfHbZn1whk8+Gj/lnfrJ6cnddsLv6JEydq27J68vj4eNiebZUd1emzOvnmzZvD9rVr14btTdal37ZtW9g3WkMgWks/fWY3s6vM7A9mtt/MXjaz71W3rzOzZ8zsYPUxvvcAWtXNn/FnJP3A3bdJ+ktJ95vZ9ZIelPSsu2+V9Gz1NYARlYbd3Wfc/YXq81OS9kvaJOkuSbuqb9sl6e5BDRJAc1/oDTozm5R0k6Q/SVrv7jPSwn8Ikq6s6bPDzKbMbCp7jQVgcLoOu5mtkvRbSd9393iVwkXcfae7d9y9k72ZA2Bwugq7mV2ghaD/yt1/V908a2YTVfuEpHhJTACtSktvtlDXeVTSfnf/8aKmpyTdK+mR6uOT2c9atmxZOJ0zK71F5YysRPTKK6+E7dnSwlGZKNv2eMOGDWF7Nv02W5Y4mi6ZvXTK7nd27OznN1lKetWqVWF7tl10tHz48ePHw77Z7zQrWW7fvj1sj0p7nU4n7Ntr6a2bOvutkr4jaa+Z7alue0gLIf+Nmd0n6Q1J3+riZwFoSRp2d/+jpLqrNr7e3+EAGBQulwUKQdiBQhB2oBCEHSgEYQcKMdQprlI8HTPb/jda1jiryWZL+2ZbF+/evbu2LZs+my3HHE0DlaTTp0+H7dHYt27dGvbNrk/Ipmpmdfjo3GTLWGfLe2dbZa9cubK2Lbs+4Lrrrgvbs/MyOTkZtkfXXmzZsiXsG/3OonzxzA4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCFGqs6ezU+O5upmS/tmdfisXnzTTTeF7U1k2ypnY4vq0dk5zWrZ2bGznx9ty5wdO/p9S/ky2dGc9OzY2Tba2diy5cGj85bNpY9+NnV2AIQdKAVhBwpB2IFCEHagEIQdKARhBwox9Dp7JKtNhjXEpO6Z1YMz2bzvJrL73URWJ8fSBvk7aQvP7EAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFKKb/dmvkvRLSRsknZO0091/amYPS/p7SfPVtz7k7k8PaqBNfRnrpt0o9X7j87q5qOaMpB+4+wtmdpmk583smartJ+7+z4MbHoB+6WZ/9hlJM9Xnp8xsv6RNgx4YgP76Qq/ZzWxS0k2S/lTd9ICZvWRmj5nZkutCmdkOM5sys6n5+fmlvgXAEHQddjNbJem3kr7v7u9K+pmkr0q6UQvP/D9aqp+773T3jrt3xsfH+zBkAL3oKuxmdoEWgv4rd/+dJLn7rLufdfdzkn4u6ZbBDRNAU2nYbeHt3Ecl7Xf3Hy+6ffHWpN+UtK//wwPQL928G3+rpO9I2mtme6rbHpJ0j5ndKMklTUv67kBGCKAvunk3/o+SlirWjmxNHcDncQUdUAjCDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhTChrmlr5nNS3p90U1jko4NbQBfzKiObVTHJTG2XvVzbH/m7kuu/zbUsH/u4GZT7t5pbQCBUR3bqI5LYmy9GtbY+DMeKARhBwrRdth3tnz8yKiObVTHJTG2Xg1lbK2+ZgcwPG0/swMYEsIOFKKVsJvZ7Wb232b2qpk92MYY6pjZtJntNbM9ZjbV8lgeM7M5M9u36LZ1ZvaMmR2sPi65x15LY3vYzN6szt0eM7uzpbFdZWZ/MLP9ZvaymX2vur3VcxeMayjnbeiv2c1suaT/kfQ3ko5Iek7SPe7+ylAHUsPMpiV13L31CzDM7GuS3pP0S3f/8+q2f5L0trs/Uv1Hudbd/2FExvawpPfa3sa72q1oYvE245LulvR3avHcBeP6Ww3hvLXxzH6LpFfd/ZC7fyzp15LuamEcI8/dd0t6+7yb75K0q/p8lxYeLENXM7aR4O4z7v5C9fkpSZ9uM97quQvGNRRthH2TpMOLvj6i0drv3SX93syeN7MdbQ9mCevdfUZaePBIurLl8Zwv3cZ7mM7bZnxkzl0v25831UbYl9pKapTqf7e6+3ZJd0i6v/pzFd3pahvvYVlim/GR0Ov25021EfYjkq5a9PVmSUdbGMeS3P1o9XFO0hMava2oZz/dQbf6ONfyeP7PKG3jvdQ24xqBc9fm9udthP05SVvN7CtmdqGkb0t6qoVxfI6ZrazeOJGZrZT0DY3eVtRPSbq3+vxeSU+2OJbPGJVtvOu2GVfL56717c/dfej/JN2phXfkX5P0j22MoWZc10j6r+rfy22PTdLjWviz7hMt/EV0n6QrJD0r6WD1cd0Ije3fJO2V9JIWgjXR0tj+SgsvDV+StKf6d2fb5y4Y11DOG5fLAoXgCjqgEIQdKARhBwpB2IFCEHagEIQdKARhBwrxv4tXhwcIGyt4AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAR/ElEQVR4nO3dX4xd1XXH8d/C5o8x/u/xHwx4ACEoqgSYEapEFVFFjYAXk4dU4SGiEpLzAFIi5aEofQiPqGoS9aGK5BQUt0qJIiUIHlAbQJFQXiwGcMBgtXbtARyPxzMG/If/tlcf5hBNzJy1Lufce8919vcjje7M3bPv2T73LN+5d+21t7m7APz5u6jrAQAYDoIdKATBDhSCYAcKQbADhVg6zIOtX7/ex8fHh3lIoChTU1Oam5uzxdpaBbuZ3S3pXyQtkfRv7v5Y9Pvj4+OanJxsc8jOkKJsJjpvZotek0PR5bEHaWJiorat8Z/xZrZE0r9KukfSzZLuN7Obmz4egMFq8579DkkH3P2gu38q6ReStvdnWAD6rU2wb5H0zoKfD1f3/Qkz22Fmk2Y2OTs72+JwANpoE+yLven5whs0d9/p7hPuPjE2NtbicADaaBPshyVdveDnqyQdaTccAIPSJthfknSDmV1rZpdI+qakZ/ozLAD91jj15u5nzOxhSf+t+dTbE+7+Rt9GNmRtUmtZ36y9bRpoVNNbmbbnZZDp0FE+b021yrO7+7OSnu3TWAAMENNlgUIQ7EAhCHagEAQ7UAiCHSgEwQ4UYqj17Beyc+fOdT2EWoPMCXc5B2CQ8xfa5ugvxDw8r+xAIQh2oBAEO1AIgh0oBMEOFIJgBwpRTOotS7UMMrU26JVp26SYBl1GGvVvm1prc+zMRRe1ex0cxdQcr+xAIQh2oBAEO1AIgh0oBMEOFIJgBwpBsAOFuKDy7G3yplke/dNPPw3bo7zrIMtA+/H4bY7d5e61XZahZsdeujQOnSVLloTt0dgG9Xzzyg4UgmAHCkGwA4Ug2IFCEOxAIQh2oBAEO1CIkcqzt6k5P3PmTNj31KlTYft7770Xtl988cW1bZdeemnYN6uNzv7dWf8s5xv57LPPwvazZ8+G7W3qvrNcdHbsLpf3jq4HSVqxYkXYHj1n2Tltes5bBbuZTUk6JemspDPuPtHm8QAMTj9e2f/G3ef68DgABoj37EAh2ga7S/qNmb1sZjsW+wUz22Fmk2Y2OTs72/JwAJpqG+x3uvs2SfdIesjMvnL+L7j7TnefcPeJsbGxlocD0FSrYHf3I9XtMUlPSbqjH4MC0H+Ng93MlpvZis+/l/Q1SXv7NTAA/dXm0/iNkp6qam+XSvpPd/+vrFOUU87yqlHN+ccffxz23b9/f9j+6quvhu2bNm2qbfvkk0/CvpdffnnYvnz58rA9y+NHslx2dt6y5yQb22WXXVbb1rZePZsj0GbN+iNHjoTtW7duDduvueaasH3VqlW1bdk5bVrv3jjY3f2gpFua9gcwXKTegEIQ7EAhCHagEAQ7UAiCHSjE0Etco7LEDz/8MOx74sSJ2raZmZmw7/PPPx+2Hzx4MGxfv359bdu7774b9l2zZk3YvnHjxrA9S8199NFHtW1ZGidLQZ08eTJsz0o5o7Fn/64spRldD1J8PWVpu6zk+dChQ2H7ddddF7Zv27atti261qT4nEfPJ6/sQCEIdqAQBDtQCIIdKATBDhSCYAcKQbADhRhqnt3dw5LJLG8alR3u3r077Pvmm2+G7dnyvFEeP5sfkOV0s5LFNktNZ7nsLJ98ySWXhO3ZHIOovDebf5Cdl6wMNXrOsqWgsxz/22+/HbZn8xOiVZs++OCDsO+NN95Y20aeHQDBDpSCYAcKQbADhSDYgUIQ7EAhCHagEEPPs0fLQWc522jb5WPHjoV9s6WksxriaA5AlqPP8skrV64M27OccJSPjs63FC/1LOVLSUe19FK8VHV23rJlsLM8e3RNZMt7Z3n27LwcPXo0bJ+eng7bI9FzSp4dAMEOlIJgBwpBsAOFINiBQhDsQCEIdqAQI1XPfvr06bB/lOvOaoCjHL2U55uz+uVIlpPdsGFD2L5s2bKwPaqnzx77nXfeCdvXrl0btme58ug5Xbo0vvyyevZsnYC33nqrti3L4UdbKkvS6tWrw/bsvEfnLdtG+8yZM7VtrfLsZvaEmR0zs70L7ltrZs+Z2f7qNp41AqBzvfwZ/zNJd5933yOSXnD3GyS9UP0MYISlwe7uL0o6fx7rdkm7qu93Sbqvz+MC0GdNP6Db6O7TklTd1r5BMbMdZjZpZpPHjx9veDgAbQ3803h33+nuE+4+sW7dukEfDkCNpsE+Y2abJam6jUvOAHSuabA/I+mB6vsHJD3dn+EAGJQ0z25mT0q6S9J6Mzss6QeSHpP0SzN7UNLbkr7R6wGjPGBWAxy1z83NhX2z/duzmvIot5nli7Na+WwP9SzfHNW7Z3X+2bHb7u8etZ87dy7s2zYXHs2dyNZOyPLo2XnJ1uOPrtfseormlETnNA12d7+/pumrWV8Ao4PpskAhCHagEAQ7UAiCHSgEwQ4UYuglrlFqINv6OFreN1syOSvVzMpro2Nnj50tJZ2VamZpnmhb5aw8NlsKOiv9zZZcjpZszkp/s9RbltLcunVrbVvb5b/bbvkctWfPSZayrMMrO1AIgh0oBMEOFIJgBwpBsAOFINiBQhDsQCGGmmeX4txqljeNyiWzvsuXLw/boxJWKV7G+tprrw37ZstYZznbrD0qeWy7LXLWvmnTprA9Woosm1+QPSdZ/+hay/rOzs6G7VlJ9BVXXBG2R3MjsrLhaOxs2QyAYAdKQbADhSDYgUIQ7EAhCHagEAQ7UIih59mjZZGzbZej5X2z3GSWh8+2yV2xYkVtW5ZzzerRs5rzTFQznp3TrF49y3VH+WIpzvNnue5sSeWsf5Trzq6X6PmWpLa7G0U169lzEp3zKL54ZQcKQbADhSDYgUIQ7EAhCHagEAQ7UAiCHSjE0PPs0ZrXWQ3w1NRUbduhQ4fCvu+//37Ynq07n9WFR7KcbrYO+MmTJxs/flbHn9WrZ7nubH5CtHVx9nxn2ypnY49yztnzne1hkK15n639Hj1+NG9Cyq+nOukVbGZPmNkxM9u74L5HzewPZran+rq30dEBDE0vL1c/k3T3Ivf/2N1vrb6e7e+wAPRbGuzu/qKk+O8pACOvzQd0D5vZa9Wf+bUbY5nZDjObNLPJaD0yAIPVNNh/Iul6SbdKmpb0w7pfdPed7j7h7hNtiwcANNco2N19xt3Puvs5ST+VdEd/hwWg3xoFu5ltXvDj1yXtrftdAKMhzbOb2ZOS7pK03swOS/qBpLvM7FZJLmlK0rd7OVi2P3uWb45yk1Gtu5Tn2bN8dPR5Q7Y/e1afnOWys5xvVNfd9rGzfcaz8xbNEchy1W2f0+h6yurws3UAsjr/bK+AI0eO1LatWrUq7BudtygHnwa7u9+/yN2PZ/0AjBamywKFINiBQhDsQCEIdqAQBDtQiKGXuEaypYGjdEaWQsrKKbM0UFR2mKVxspLEbKnpLMUU9c/KQLNjty23jJbwzspMoxJVKd/KOkphrVlTO8NbUp6yzEqes5Rk1D8751mKuvaYjXoBuOAQ7EAhCHagEAQ7UAiCHSgEwQ4UgmAHCjFSWzZn+eQTJ07Uth09ejTsmy3tm4nyrtmxs+1/M9nYo3LLrEQ1m9uQ5aOzXHdUKprl6LMy00y2zHVkbm4ubM9y4dnci2heRzZnpOmy5ryyA4Ug2IFCEOxAIQh2oBAEO1AIgh0oBMEOFGLoefYot5otuRzlTU+fPh32bZOrluLcZ1Z3ndWUZ7nqLI8f5dKzZYmzXHT2b8vyyQcOHGjcNzsvWa47ek6jraSlfLvoK6+8MmzPHj+qSb/pppsa943wyg4UgmAHCkGwA4Ug2IFCEOxAIQh2oBAEO1CIkVo3ft26dWH7li1battuu+22sG+WR8/qumdmZmrbsnzx9ddfH7ZnY8vWvI/WX8/Whc/WR9+4cWPYnj1+VHvd5t8ltZuXkY07er4laXx8PGzP5giMjY3Vtm3bti3sG20RHs3pSF/ZzexqM/utme0zszfM7DvV/WvN7Dkz21/dxqscAOhUL3/Gn5H0PXf/C0l/JekhM7tZ0iOSXnD3GyS9UP0MYESlwe7u0+7+SvX9KUn7JG2RtF3SrurXdkm6b1CDBNDel/qAzszGJd0mabekje4+Lc3/hyBpQ02fHWY2aWaTx48fbzdaAI31HOxmdoWkX0n6rruf7LWfu+909wl3n8g+gAMwOD0Fu5ldrPlA/7m7/7q6e8bMNlftmyUdG8wQAfRDmnqz+RrHxyXtc/cfLWh6RtIDkh6rbp9OD7Z0abg08S233BL2X716dW3bypUrw75ZGmd2djZsj9Jn2dK+t99+e9ielSxmZapZ+iySlf5u2LDou7M/ypaqbrMcdDa2bFvkqH9WNpwtJX3VVVeF7dFW1VK8vHiW1lu2bFltW3Qt9nKV3CnpW5JeN7M91X3f13yQ/9LMHpT0tqRv9PBYADqSBru7/05S3QoGX+3vcAAMCtNlgUIQ7EAhCHagEAQ7UAiCHSjESG3ZnOXKo1x3lovOZCWuUT45yntKeZlotnVxtpxzlFvN5gC0XQY7G3vTZY97kY09mltx8mQ8CTTaUrmXY2dzH6IS2Kz0NyrPDa+F8FEB/Nkg2IFCEOxAIQh2oBAEO1AIgh0oBMEOFGKoeXYzC/OPWc42ql9us6Tx52NrKstFZ+0YjGj+Q5bLbjv3oU3/7LGbXqu8sgOFINiBQhDsQCEIdqAQBDtQCIIdKATBDhRipOrZM1G+um1uss24BpnDR3Ndnve2efpB9OWVHSgEwQ4UgmAHCkGwA4Ug2IFCEOxAIQh2oBBpsJvZ1Wb2WzPbZ2ZvmNl3qvsfNbM/mNme6uvetoMxs/CrS6M6LjTT9lpz9/CrTf9B6WVSzRlJ33P3V8xshaSXzey5qu3H7v7PAxsdgL7pZX/2aUnT1fenzGyfpC2DHhiA/vpS79nNbFzSbZJ2V3c9bGavmdkTZramps8OM5s0s8nZ2dlWgwXQXM/BbmZXSPqVpO+6+0lJP5F0vaRbNf/K/8PF+rn7TnefcPeJsbGxPgwZQBM9BbuZXaz5QP+5u/9aktx9xt3Puvs5ST+VdMfghgmgrV4+jTdJj0va5+4/WnD/5gW/9nVJe/s/PAD90sun8XdK+pak181sT3Xf9yXdb2a3SnJJU5K+3csBo7RGl9v/Ztos/YsLT3YtdlnC2vSxe/k0/neSFnv0ZxsdEUAnmEEHFIJgBwpBsAOFINiBQhDsQCEIdqAQQ19KOpLlD6PtnrMc/CCXksaFJ3u+s+XB215vXeCVHSgEwQ4UgmAHCkGwA4Ug2IFCEOxAIQh2oBA2yKVrv3Aws1lJby24a72kuaEN4MsZ1bGN6rgkxtZUP8e21d0XXf9tqMH+hYObTbr7RGcDCIzq2EZ1XBJja2pYY+PPeKAQBDtQiK6DfWfHx4+M6thGdVwSY2tqKGPr9D07gOHp+pUdwJAQ7EAhOgl2M7vbzP7HzA6Y2SNdjKGOmU2Z2evVNtSTHY/lCTM7ZmZ7F9y31syeM7P91e2ie+x1NLa+b+PdcGx124x3eu6Guf35oscf9nt2M1si6X8l/a2kw5JeknS/u7851IHUMLMpSRPu3vkEDDP7iqTTkv7d3f+yuu+fJL3r7o9V/1Gucfd/GJGxPSrpdNfbeFe7FW1euM24pPsk/b06PHfBuP5OQzhvXbyy3yHpgLsfdPdPJf1C0vYOxjHy3P1FSe+ed/d2Sbuq73dp/mIZupqxjQR3n3b3V6rvT0n6fJvxTs9dMK6h6CLYt0h6Z8HPhzVa+727pN+Y2ctmtqPrwSxio7tPS/MXj6QNHY/nfOk23sN03jbjI3Pummx/3lYXwb7Y4lyjlP+70923SbpH0kPVn6voTU/beA/LItuMj4Sm25+31UWwH5Z09YKfr5J0pINxLMrdj1S3xyQ9pdHbinrm8x10q9tjHY/nj0ZpG+/FthnXCJy7Lrc/7yLYX5J0g5lda2aXSPqmpGc6GMcXmNny6oMTmdlySV/T6G1F/YykB6rvH5D0dIdj+ROjso133Tbj6vjcdb79ubsP/UvSvZr/RP7/JP1jF2OoGdd1kn5ffb3R9dgkPan5P+s+0/xfRA9KWifpBUn7q9u1IzS2/5D0uqTXNB9Ymzsa219r/q3ha5L2VF/3dn3ugnEN5bwxXRYoBDPogEIQ7EAhCHagEAQ7UAiCHSgEwQ4UgmAHCvH/dbjmuVN/Y4sAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAARiElEQVR4nO3dW4zd1XXH8d/CgDHGd48vMhdzp1AJgsaoElXkEhoZeACEUsJDRCVU5wFEIuWhiD6EJ4RQCcpDFckpKAYFokjEggeLxrIiIMIODLYLBlNMbUNsBs/Y+DJgjG+rD3OoBjP/tYbzPze8vx/JmpmzZs/ZPp6fz8xZ/723ubsAnPpO6/YEAHQGYQcKQdiBQhB2oBCEHSjE6Z28s7lz5/rixYs7eZdAUXbs2KE9e/bYeLVaYTezZZJ+KWmSpP9090eiz1+8eLEGBgbq3CWAQH9/f2Wt6R/jzWySpP+QdJOkKyXdZWZXNvv1ALRXnd/Zr5P0vrtvc/cjkn4n6dbWTAtAq9UJ+yJJfx3z8c7GbV9hZsvNbMDMBoaHh2vcHYA66oR9vBcBvnbtrbuvcPd+d+/v6+urcXcA6qgT9p2Szhvz8bmSPqo3HQDtUifsr0u61MwuNLMzJf1Q0gutmRaAVmu69ebux8zsPkn/pdHW25Pu/nbLZgagpWr12d19taTVLZoLgDbiclmgEIQdKARhBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEB3dShq9JzvY8/jx47XqZuPuaixJOu20es810dfO1D3QNBt/9OjRsB793c8888ymx4bjmhoF4FuHsAOFIOxAIQg7UAjCDhSCsAOFIOxAIeizn+JOnDgR1g8ePBjWV6+ONw/etm1bWJ8zZ05lLevRZ2bNmhXWp02bVlnbtWtXODZ7XKZMmRLW9+3bF9aXLVtWWbvqqqvCsWeffXZYr8IzO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhaDPfoo7duxYWN+4cWNYf/rpp8P6O++8E9Y/++yzytrpp8ffftl69awe9fH379/f9Fgp7/FPnz49rO/cubOy9vDDD4djm+2z1wq7me2QNCLpuKRj7t5f5+sBaJ9WPLP/g7vvacHXAdBG/M4OFKJu2F3SH83sDTNbPt4nmNlyMxsws4Hh4eGadwegWXXDfr27XyvpJkn3mtl3T/4Ed1/h7v3u3t/X11fz7gA0q1bY3f2jxtshSaskXdeKSQFovabDbmZTzWzal+9L+r6kza2aGIDWqvNq/HxJqxq9ztMlPePuL7ZkVmiZrF/87rvv1qoPDg6G9ajPn+29fsYZZ4T1yZMnh/Wojz9p0qRw7JEjR8J6tt492zc+uv4g2xe+2f3ymw67u2+TdHWz4wF0Fq03oBCEHSgEYQcKQdiBQhB2oBAscT0FRC2sAwcOhGPXrVsX1rMtkbP2WLSVddb+yrbBzo42jlpvWWsta39lS1izxyWae7b0t1k8swOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAj67Ke4aCnlROrZcso6y1SzXvThw4fDetaHj5x11llhPduuOevxz5w5M6yPjIxU1rLHtFk8swOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAj67KeAqN+8a9eucOyHH34Y1rMjn7Oe8JQpUyprWZ88W9edrTnP1svXue/sGoHM9u3bK2tRD16SZsyY0dR98swOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAh6LOfAqJe+Pr168Oxu3fvDutZHz3rR2dHF0ey46azPn10pHP298rWq2fmzZsX1oeGhiprH3zwQTh20aJFTc0pfWY3syfNbMjMNo+5bbaZrTGzrY23s5q6dwAdM5Ef438jadlJtz0gaa27XyppbeNjAD0sDbu7vyzpk5NuvlXSysb7KyXd1uJ5AWixZl+gm+/ug5LUeFv5C4qZLTezATMbGB4ebvLuANTV9lfj3X2Fu/e7e39fX1+77w5AhWbDvtvMFkpS4231S4sAekKzYX9B0t2N9++W9HxrpgOgXdI+u5k9K2mppLlmtlPSzyU9Iun3ZnaPpA8l/aCdkyxd1hOOzmB/5ZVXwrGHDh0K69ma8Wx/9WhtdtQHl/IeflaP+vRZj/6LL74I69n57NnjFt3/c889F45dsmRJZS36XknD7u53VZS+l40F0Du4XBYoBGEHCkHYgUIQdqAQhB0oBEtcvwWypZ4bNmyorEVLKaV8u+WshZS1BaOjkbMjmbNlptl9R49b3SWsmU8+OXk5yVdFS383bdoUjo2O2Y7+zjyzA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQCPrsPSDrF3/++edhfdWqVZW1bdu2hWOzpZ5Rn1zKj3SOloqec845tb52dv1BND67viBb4prN7dNPPw3r0ZHP2b9ZtNX0kSNHKms8swOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAj67D0g63W/+uqrYf21116rrGXbLddd1531qyNmFtanTp0a1rPrD6K5ZY9LJuuzZ3OP5rZnz55wbLSePfpe4pkdKARhBwpB2IFCEHagEIQdKARhBwpB2IFC0GfvgKyPnu3t/thjj4X1aI/y7FjkrE+e9cKj9dNSvOY8m1u2Xj0TrUmP9m2X8h5+1qfft29fWI/2CcjWwkfHZEf7/KfP7Gb2pJkNmdnmMbc9ZGa7zGxT48/N2dcB0F0T+TH+N5KWjXP74+5+TePP6tZOC0CrpWF395clxWfZAOh5dV6gu8/M3mz8mD+r6pPMbLmZDZjZwPDwcI27A1BHs2H/laSLJV0jaVBS5StI7r7C3fvdvb+vr6/JuwNQV1Nhd/fd7n7c3U9I+rWk61o7LQCt1lTYzWzhmA9vl7S56nMB9Ia0z25mz0paKmmume2U9HNJS83sGkkuaYekH7dxjh2R7d0eyfro2Vndjz76aFjfuHFjWI/WpE+ZMiUcm/XJs3XbWT3qZ2f3nfWyo36zJO3du7eylq03z/bLP3jwYFjPriGI/l2yax/mzp1bWYseszTs7n7XODc/kY0D0Fu4XBYoBGEHCkHYgUIQdqAQhB0oRE8tcc3aX1GLK2t/1fnaUtwmio7QlaSnnnoqrL/44othPTreV4pbNdlSzUOHDoX1rLVWp32WtZjq/L0lacGCBZW1qH0lSbt37w7r2eMye/bssB4dV51t7z1jxozKWvSY8MwOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAheqrPfvjw4bAeLRUdGRkJx2b95Gz73sHBwcraSy+9FI5dt25dWM9MmzYtrEc932g75WzsRGTLUKOtqLNtqqNetCRdcMEFYX3hwoWVtexxiY5FlvLvl+y6jWjpb7a8NurDR48pz+xAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSio312dw/XP69ZsyYcv3bt2qbvO+sH79mzJ6xHfdOtW7eGY7PrB7J+ctYTjv5u2Xr2bMvjrBeeHascHSF87rnnhmOvvfbasL506dKwPnPmzMpa9r20ZcuWsJ710bO19lEvPdsiu9lrF3hmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEB3tsx87dkzDw8OV9ccffzwcH+3lnfUmM1k/efr06ZW17PjeTDY+62VH1y5kY7P17NnjktUvv/zyytoNN9wQjr3jjjvCerQvvBRfn7B58+Zw7Lx588J6tt49OxI6+vpz5swJxzYrfWY3s/PM7E9mtsXM3jaznzRun21ma8xsa+PtrLbMEEBLTOTH+GOSfubufyPp7yTda2ZXSnpA0lp3v1TS2sbHAHpUGnZ3H3T3DY33RyRtkbRI0q2SVjY+baWk29o1SQD1faMX6MxssaTvSPqLpPnuPiiN/ocgadxfQsxsuZkNmNnA3r17680WQNMmHHYzO0fSc5J+6u4TfkXK3Ve4e7+797frhQcAuQmF3czO0GjQf+vuf2jcvNvMFjbqCyUNtWeKAFohbb3ZaG/lCUlb3P0XY0ovSLpb0iONt89nX+vo0aMaGqr+P+G9994Lx0fLArMlhdkRvdFSTCk+8jk7DnrKlClhPWvj1FlGmh17nC3VnDUrbrJccsklYf3GG2+srN15553h2OzfLBO13ubPnx+O7evrC+vZVtLZ92PUmjv//PPDsVm7s8pE+uzXS/qRpLfMbFPjtgc1GvLfm9k9kj6U9IOmZgCgI9Kwu/ufJVX9V/K91k4HQLtwuSxQCMIOFIKwA4Ug7EAhCDtQiI4ucZ00aVK4bfIVV1wRjt+3b19lLVuqGS2PlfItlaNjk+tu15z14bOtqKM+e9bvnT17dli/+uqrw/ott9wS1m+//fbKWrYMNLv2IbtGIPqeyLYWzx63bG7ZMdvR98SFF14Yjm0Wz+xAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSio332yZMn66KLLqqsP/PMM+H47du3V9Y2btwYjl2/fn1Yj7a4lqTBwcHKWnak8sjISFjP1idnffY66/wvu+yysH7//feH9f7+/rCeXWMQaXbd9peiPny2Hv3QoUNhPXtcsz0IouOksx5+s3hmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEJbted5K/f39PjAw0PT4Onu311n7LEn79++vrGX73UdHKkvSrl27wnomOq764osvDsdGRypL+Vr7dvWEJ6LOv/mBAwfCsR9//HFYr3vtRLQnfranffTvvWTJEg0MDIx75zyzA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQiImcz36epKckLZB0QtIKd/+lmT0k6V8kfbkQ/EF3X92uiTbm0lRNyvvB2T7iCxYsaKqG9sn+zaM159l++Vn922gim1cck/Qzd99gZtMkvWFmaxq1x93939s3PQCtMpHz2QclDTbeHzGzLZIWtXtiAFrrG/3ObmaLJX1H0l8aN91nZm+a2ZNmNqtizHIzGzCzgWzrJwDtM+Gwm9k5kp6T9FN3PyjpV5IulnSNRp/5HxtvnLuvcPd+d+/v6+trwZQBNGNCYTezMzQa9N+6+x8kyd13u/txdz8h6deSrmvfNAHUlYbdRl/yfELSFnf/xZjbF475tNslbW799AC0ykRejb9e0o8kvWVmmxq3PSjpLjO7RpJL2iHpx22ZIYCWmMir8X+WNF5Ds609dQCtxRV0QCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCEIO1CIjh7ZbGbDkj4Yc9NcSXs6NoFvplfn1qvzkphbs1o5twvcfdz93zoa9q/dudmAu/d3bQKBXp1br85LYm7N6tTc+DEeKARhBwrR7bCv6PL9R3p1br06L4m5Nasjc+vq7+wAOqfbz+wAOoSwA4XoStjNbJmZ/Y+ZvW9mD3RjDlXMbIeZvWVmm8xsoMtzedLMhsxs85jbZpvZGjPb2ng77hl7XZrbQ2a2q/HYbTKzm7s0t/PM7E9mtsXM3jaznzRu7+pjF8yrI49bx39nN7NJkt6T9I+Sdkp6XdJd7v5ORydSwcx2SOp3965fgGFm35X0qaSn3P1vG7c9KukTd3+k8R/lLHf/1x6Z20OSPu32Md6N04oWjj1mXNJtkv5ZXXzsgnn9kzrwuHXjmf06Se+7+zZ3PyLpd5Ju7cI8ep67vyzpk5NuvlXSysb7KzX6zdJxFXPrCe4+6O4bGu+PSPrymPGuPnbBvDqiG2FfJOmvYz7eqd46790l/dHM3jCz5d2ezDjmu/ugNPrNI2lel+dzsvQY70466Zjxnnnsmjn+vK5uhH28o6R6qf93vbtfK+kmSfc2flzFxEzoGO9OGeeY8Z7Q7PHndXUj7DslnTfm43MlfdSFeYzL3T9qvB2StEq9dxT17i9P0G28HeryfP5fLx3jPd4x4+qBx66bx593I+yvS7rUzC40szMl/VDSC12Yx9eY2dTGCycys6mSvq/eO4r6BUl3N96/W9LzXZzLV/TKMd5Vx4yry49d148/d/eO/5F0s0Zfkf9fSf/WjTlUzOsiSf/d+PN2t+cm6VmN/lh3VKM/Ed0jaY6ktZK2Nt7O7qG5PS3pLUlvajRYC7s0t7/X6K+Gb0ra1Phzc7cfu2BeHXncuFwWKARX0AGFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UIj/AweWDio91NCDAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAARIklEQVR4nO3dX2zd9XnH8c9DghNwnH/EyRyHJoCCGH80WplogqliqlYRbkgvOpWLiiG09AKkVurFELsol2haW/ViqhQGajp1VJVaBBdoK0KVUG8qDEpDsmgLgxBMHMcmREnIP5I8u/CPyQk+z2PO75zzO/X3/ZIs2+fxz+fr43zyOz7P7/v9mrsLwOJ3TdMDANAbhB0oBGEHCkHYgUIQdqAQS3t5Z+vWrfMtW7b08i77woULF2rVz58/3/bxZhYeOzAwENazbs3Fixfbrl977bXhscuXLw/rg4ODYX3JkiVhfTE6dOiQZmZm5v2l1wq7mT0g6SeSlkj6V3d/Jvr6LVu2aHx8vM5d9qXLly+H9YmJibD+/vvvh/X33nsvrH/wwQcta0uXxr/iTZs2hfVPP/00rJ84cSKsT01Ntaxt3LgxPPa2224L69u2bQvrK1asaFlbrP8RjI2Ntay1/TTezJZI+hdJ2yXdLulhM7u93e8HoLvq/M2+TdI77v6uu1+Q9EtJD3VmWAA6rU7YRyXNff44Ud12BTPbaWbjZjY+PT1d4+4A1FEn7PO9CPC5V3PcfZe7j7n72PDwcI27A1BHnbBPSLpxzuebJB2pNxwA3VIn7G9I2mpmN5nZgKRvSXq5M8MC0Gltt97c/aKZPSHpPzXbenve3fd3bGR9Juo3f/LJJ+Gxzz77bFjfu3dvWM9aczMzMy1rH374YXhs1prL+vRZnz3qhWettbvuuiusHzkSP5HcsWNHy1rUlpMWZ2uuVp/d3V+R9EqHxgKgi7hcFigEYQcKQdiBQhB2oBCEHSgEYQcK0dP57P0sm7cdzRk/ePBgeOzk5GRYz+ZtZ/3oaBrqxx9/HB57+PDhsH7u3LmwfvPNN4f1aArt9ddfHx578uTJsL5/f3xZxx133NGydvvt8QTN6667Lqxn1x/0I87sQCEIO1AIwg4UgrADhSDsQCEIO1AIWm+VbDnnaAXZrL2VTYHN6iMjI2H9+PHjLWujo59bKewKa9euDeuZrEUVtQWzpaSz38lHH30U1qNl0E6fPh0em42t7tTgJnBmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEMX02bOdVrN61C+uu61V1k9euXJlWL906VLLWjZFNdulNbvvrJ8cTWPNpvZm9531uqOtrrNtsLPHJVtqmj47gMYQdqAQhB0oBGEHCkHYgUIQdqAQhB0oRDF99qgXLeVLSUdbE2fz0c+cORPWs37ywMBAWB8aGmpZy3rRy5YtC+tZvzj7/tHYs1519rjU7YXXkV2Xcc01/XcerRV2Mzsk6ZSkS5IuuvtYJwYFoPM6cWb/a3ef6cD3AdBF/fdcA0BX1A27S/qtmb1pZjvn+wIz22lm42Y2XvcacgDtqxv2+9z9K5K2S3rczL569Re4+y53H3P3seHh4Zp3B6BdtcLu7keq98ckvShpWycGBaDz2g67mQ2a2dBnH0v6uqR9nRoYgM6q82r8BkkvVn3YpZL+3d3/oyOj6oK6ffao35z1g7M1yLP7zkTHZ/3gbL571uPP+slnz55tWcvWnK+rTo8/uq5Cyn/ufpzv3nbY3f1dSX/RwbEA6CJab0AhCDtQCMIOFIKwA4Ug7EAhFs0U1262r6S41ZK11qLllKV8Kens+KituGrVqvDYbLvpkydPhvUNGzaE9Whs2VLSWdswa92tWbMmrEfqTlGt08rtFs7sQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4Uopg+ezZlMTs+6hdnvehse+Csp5v1k6OxZ/d90003hfWJiYmwfvz48bC+adOmlrVsGevsccl+tqhPn015zqao1t0CvImlpjmzA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQiEXTZ8+27816sllfNeqLHjt2LDz28OHDYT3qRUv5vO+oz3/69Onw2K1bt4b1bL770aNHw/rmzZtb1rLfWTbPP1tHIDo+W0K77hLcmejfW7fmunNmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEIumz15nPrqU91Wj+fDr168Pj836yZk6P1t27OjoaFifmZkJ62fOnAnr0c+erTGQ1bM18aPrE7JxZ/8eMtnvvNvbVc8nPbOb2fNmdszM9s25ba2ZvWpmB6v37a/GD6AnFvI0/meSHrjqticlvebuWyW9Vn0OoI+lYXf31yVdvfbQQ5J2Vx/vlrSjw+MC0GHtvkC3wd0nJal63/KPVjPbaWbjZjY+PT3d5t0BqKvrr8a7+y53H3P3seHh4W7fHYAW2g37lJmNSFL1Pp72BaBx7Yb9ZUmPVB8/IumlzgwHQLekfXYze0HS/ZLWmdmEpB9IekbSr8zsMUmHJX2zm4NciGwN8myv7qwPH/VN77nnnvDYwcHBsJ7NX856stHPfuLEifDYbH/1lStXhvVsLn/U686uAcjq2Tz/6GfL9rwfGBiodd/ZuvBN7M+eht3dH25R+lqHxwKgi7hcFigEYQcKQdiBQhB2oBCEHSjEopnimrUysmWHly6NH4qoFbNx48bw2Gyp6KzNk031jKZjZi2mbAntrDWXTQWt0xbMpqFmU1zXrl3bspa1M7PHpYnWWV2c2YFCEHagEIQdKARhBwpB2IFCEHagEIQdKMSi6bPXlfVNo3o23XFkZCSsZ330TDS2rBed9ZOjXrWUX79w6tSplrVsGmg2tuzaiGhsi7GPnuHMDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIeizd0DWL8563VEvWpLOnj0b1qM55dm87axPfsMNN9Q6PhpbNhc+W947m6u/GHvldXBmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEPTZe2BoaCisZ332c+fOhfVoa+O668Znc/VXrFgR1qO136NtsKV8G+7Vq1eH9ej6hxJ78OmZ3cyeN7NjZrZvzm1Pm9mHZranenuwu8MEUNdCnsb/TNID89z+Y3e/u3p7pbPDAtBpadjd/XVJx3swFgBdVOcFuifMbG/1NH9Nqy8ys51mNm5m49PT0zXuDkAd7Yb9p5JukXS3pElJP2z1he6+y93H3H1seHi4zbsDUFdbYXf3KXe/5O6XJT0raVtnhwWg09oKu5nNXRv5G5L2tfpaAP0h7bOb2QuS7pe0zswmJP1A0v1mdrckl3RI0ne6OMa+l81nX79+fVg/cuRIWM/60VGvfN26deGxWR896uFL0po1LV+ukRT32bN96bP57NnYs99LadKwu/vD89z8XBfGAqCL+K8PKARhBwpB2IFCEHagEIQdKARTXHsga51lU1yzqZ6RqPUl5VNcs9ZbNlU02o46W4Y625I5m76LK3FmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEPTZeyCbypn1qrNe+ODgYMval770pfDYrJedTTPdvHlzWD9+vPXyhdG4pXxs2eOKK3FmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEPTZe2B0dDSsZ330bCedqN+cbRed3Xcmu0Yg6tNHc90XUj9//nxYx5U4swOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAj67B2Q9Zoz2ZzxycnJtu8/6/Fn68JnY8t64SdOnGhZu3z5cnhstuVydt+4UnpmN7Mbzex3ZnbAzPab2Xer29ea2atmdrB6H2/UDaBRC3kaf1HS9939zyX9paTHzex2SU9Kes3dt0p6rfocQJ9Kw+7uk+7+VvXxKUkHJI1KekjS7urLdkva0a1BAqjvC71AZ2ZbJH1Z0h8kbXD3SWn2PwRJ61scs9PMxs1sfHp6ut5oAbRtwWE3sxWSfi3pe+5+cqHHufsudx9z97FsQgeA7llQ2M3sWs0G/Rfu/pvq5ikzG6nqI5KOdWeIADohbb3ZbF/nOUkH3P1Hc0ovS3pE0jPV+5e6MsJFIJtGeuHChVrfP9r6uG57KmvNZe2xaPrtuXPnwmOzLZmzlmfdluhis5A++32Svi3pbTPbU932lGZD/isze0zSYUnf7M4QAXRCGnZ3/72kVv9Ffq2zwwHQLVwuCxSCsAOFIOxAIQg7UAjCDhSCKa49sGzZsrCeTfU8e/ZsWI/6ydlyy3WnmWa97Oj+ly9fHh6bPW7Zls7ZNQKl4cwOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAh6LP3gWy55jNnzoT1aD57VJPyXnRWz+bqR338rIef1bNrBOizX4kzO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhaDP3gFZPzfrk2fztrN+c7T++smT8eY92Xz3bL56ti599P2jNeWlvI9e9xqC0nBmBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEAvZn/1GST+X9GeSLkva5e4/MbOnJf29pOnqS59y91e6NdB+lvVzs35y1i/OetnR/Wd7oNedz75q1aqwHq3tnn3vwcHBsF7ncSnRQi6quSjp++7+lpkNSXrTzF6taj9293/u3vAAdMpC9meflDRZfXzKzA5IGu32wAB01hf6m93Mtkj6sqQ/VDc9YWZ7zex5M1vT4pidZjZuZuPT09PzfQmAHlhw2M1shaRfS/qeu5+U9FNJt0i6W7Nn/h/Od5y773L3MXcfGx4e7sCQAbRjQWE3s2s1G/RfuPtvJMndp9z9krtflvSspG3dGyaAutKw2+y0p+ckHXD3H825fWTOl31D0r7ODw9Apyzk1fj7JH1b0ttmtqe67SlJD5vZ3ZJc0iFJ3+nKCP8EZC2eOlNUJWlqaiqsr169umUtm6KaLWOdLRWdjT2a4pq1HOs+brjSQl6N/72k+f7FFNlTB/5UcQUdUAjCDhSCsAOFIOxAIQg7UAjCDhSCpaQ7IOtF33rrrWH90UcfDetHjx4N61Gf/8477wyPHRoaCuvRFFVJ2r59e1iPpqlm1yfce++9Yf2WW24J69nvpTSc2YFCEHagEIQdKARhBwpB2IFCEHagEIQdKIT1crldM5uW9P6cm9ZJmunZAL6Yfh1bv45LYmzt6uTYNrv7vOu/9TTsn7tzs3F3H2tsAIF+HVu/jktibO3q1dh4Gg8UgrADhWg67Lsavv9Iv46tX8clMbZ29WRsjf7NDqB3mj6zA+gRwg4UopGwm9kDZvbfZvaOmT3ZxBhaMbNDZva2me0xs/GGx/K8mR0zs31zbltrZq+a2cHq/bx77DU0tqfN7MPqsdtjZg82NLYbzex3ZnbAzPab2Xer2xt97IJx9eRx6/nf7Ga2RNL/SPobSROS3pD0sLv/V08H0oKZHZI05u6NX4BhZl+VdFrSz939zuq2f5J03N2fqf6jXOPu/9AnY3ta0ummt/GudisambvNuKQdkv5ODT52wbj+Vj143Jo4s2+T9I67v+vuFyT9UtJDDYyj77n765KOX3XzQ5J2Vx/v1uw/lp5rMba+4O6T7v5W9fEpSZ9tM97oYxeMqyeaCPuopA/mfD6h/trv3SX91szeNLOdTQ9mHhvcfVKa/ccjaX3D47lauo13L121zXjfPHbtbH9eVxNhn28rqX7q/93n7l+RtF3S49XTVSzMgrbx7pV5thnvC+1uf15XE2GfkHTjnM83STrSwDjm5e5HqvfHJL2o/tuKeuqzHXSr98caHs//66dtvOfbZlx98Ng1uf15E2F/Q9JWM7vJzAYkfUvSyw2M43PMbLB64URmNijp6+q/rahflvRI9fEjkl5qcCxX6JdtvFttM66GH7vGtz93956/SXpQs6/I/6+kf2xiDC3GdbOkP1Zv+5sem6QXNPu07lPNPiN6TNINkl6TdLB6v7aPxvZvkt6WtFezwRppaGx/pdk/DfdK2lO9Pdj0YxeMqyePG5fLAoXgCjqgEIQdKARhBwpB2IFCEHagEIQdKARhBwrxfyPE4e3bezwJAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATw0lEQVR4nO3dXWyU55UH8P/hw3wYMAaDMZ9mEVKCAgU0ISslKqyirZLckCrqqkipWClaGimRWqkXTbIXzWUUbVv1YlWJblDpqpuqUhsFRdFuI1SFVJEQDiKBAAmQGDAY20D4/gpw9sIvkUM85+/OOzPvmOf/kyzbc/zMPPOOj8cz5z3PY+4OEbn3jSl6AiJSH0p2kUQo2UUSoWQXSYSSXSQR4+p5Y21tbd7Z2VnPm6yaqGpx+/btiscCwPXr18P4jRs3wnhk4sSJYXzs2LFhnM2dzS26b+PHjw/HNjU1hfEJEyaEcTOrKDaadXd34/Tp08PeuVzJbmaPAfgVgLEA/svdX4l+vrOzE11dXXlusmJ5S4w3b94sG7t69Wo4liXz559/niseue+++8J4S0tLGI/uNwAcPXo0jEdzb29vD8cuXrw4jLMnjuiPBfsjN2ZMvn96i/pjUiqVysYqvkdmNhbAfwJ4HMAyABvMbFml1ycitZXnz9caAIfd/TN3vwHgDwDWV2daIlJteZJ9HoDjQ77vyS77GjPbZGZdZtY1MDCQ4+ZEJI88yT7ci5JvvDB2983uXnL30qxZs3LcnIjkkSfZewAsGPL9fAAn801HRGolT7LvArDUzBabWROA7wPYVp1piUi1VVx6c/ebZvY8gP/DYOlti7t/XLWZVRkrvbHy1nvvvVc2dvDgwXDsqVOnwvixY8fC+Kefflrx9U+ZMiUcy0pvPT09YZydYzBnzpyKYgCwaNGiMM7KivPnzy8bW716dTh21apVYTxPjb8ouers7v42gLerNBcRqSGdLiuSCCW7SCKU7CKJULKLJELJLpIIJbtIIuraz15LrN574sSJMP7000+H8agvfOHCheFY1qq5ZMmSMP7QQw+F8TNnzoTxyOHDh8N4a2trGF+3bl0YnzlzZtnYpEmTwrHnzp0L45cuXQrju3fvLhvbuXNnOHbt2rVhfMOGDWF88uTJYbyIOrye2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJxD1TemOroL744othnK16++yzz5aNTZ8+PRzLWjnnzp0bxtkKrlF57MsvvwzHzpv3jZXEvmbcuPhXpKOjI4zfunWrbIyVpxYsWBDGe3t7w3jU1sxKjt3d3WH80KFDYXz58uVhnK1uWwt6ZhdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUSMqjp7VDc9ffp0OPatt94K46ymG2F1drYTDtv2mG1tHLVLXrlyJdd1Nzc3h/G2trYwHu1ge+3atXAsq0Wz1uCzZ8+WjbEWU9Y+e/HixTDOWq5VZxeRmlGyiyRCyS6SCCW7SCKU7CKJULKLJELJLpKIe6bOfvz48XAsq3uynvJo2eOpU6eGY1lPOdtWOVqOGYh7+dvb28Ox58+fD+Osjs560qP7xmrVrBYe1dGBuJeenVfBzk9gW3yzLaHZ+Q21kCvZzawbwEUAtwDcdPdSNSYlItVXjWf2f3L3+PQ1ESmcXrOLJCJvsjuAv5jZB2a2abgfMLNNZtZlZl0DAwM5b05EKpU32R9299UAHgfwnJl9++4fcPfN7l5y9xJrCBGR2smV7O5+MvvcD+ANAGuqMSkRqb6Kk93Mms1s6p2vAXwHwL5qTUxEqivPu/HtAN7IaqHjAPyPu/9vVWZVgWh7XoDXdFndNRrPep9ZTzjrbWb97i0tLRWPHTMm/nsf9aMDfL3+6ByBpqamcCyro0+YMCGMX758uWyMbRd99erVMM62AGfHPTo/oVbbOVec7O7+GYBvVXEuIlJDKr2JJELJLpIIJbtIIpTsIolQsoskYlS1uEYti0eOHAnHsuWeWRlo//79ZWOzZ88Ox7IWVrYt8sSJE8N4VD5jpTVWFmQlS3bfouWiWesva0tmJarotnt6esKxDHtMWMmyCHpmF0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRIyqOntUdz116lQ4li15zFbReffdd8vGWM22o6MjjEdLZAP52kyjcxMAflzYksesFh6dQ8DuFzsHgNXpjx07VjbW29sbjmXLd7MWVhaPHvNatbjqmV0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRIxqursUW2SLf177ty5MB71PrPrZ0tJszjbHpidAxAtRc36zVk9mNXp2XbVrA6f57bZ+QlRL/6ZM2fCsez8AjY3VsePtghnaxBUSs/sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiFFVZ4+wem+pVArjrH85wrZ77uvrC+NRzRXga9pHPeNsa2K2nj7rOWf16AirVbNzI9hjtnTp0rKxXbt2hWPZ7xM792FUrhtvZlvMrN/M9g25bIaZvWNmh7LPrbWdpojkNZJ/438L4LG7LnsBwHZ3Xwpge/a9iDQwmuzuvgPA2bsuXg9ga/b1VgBPVnleIlJllb5B1+7uvQCQfS672ZmZbTKzLjPrGhgYqPDmRCSvmr8b7+6b3b3k7iX2poaI1E6lyd5nZh0AkH3ur96URKQWKk32bQA2Zl9vBPBmdaYjIrVC6+xm9jqAdQDazKwHwM8AvALgj2b2DIBjAL5Xy0neEfX5slp3S0tLGJ83b14Yf+SRR8rGFi1aFI795JNPwjjr+Z4wYULFcbaPOKuTszo8m3vUa8/OAWB7w7Ne/BUrVpSNbd++PRw7Y8aMMD5nzpwwzs6NKAJNdnffUCb0aJXnIiI1pNNlRRKhZBdJhJJdJBFKdpFEKNlFEnHPtLiylkLWDslaGhcvXlw2FpV4AODDDz8M46zVk21dHMVZaW3atGlhPNoOGshXYmJbE0etuwC/b52dnWVja9euDcdevnw5jLPtotlxK4Ke2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBGjqs4e1aPZNrcsvn///jCep5Z96tSpMM7q9KzVM2pDZS2u7LiwejITbavMzo1gt83q9NFjxpYW37FjRxhn52Ww+8a2m64FPbOLJELJLpIIJbtIIpTsIolQsoskQskukgglu0giRlWdPcJqsqxv+/333w/jjz5afjHdaLlkgG/ZzJZrZksuR1hf9ZUrV3KNZ+cYREtNs9s+c+ZMGF+yZEkYj+rsra3xxsOsn53V0dny3+wcgVrQM7tIIpTsIolQsoskQskukgglu0gilOwiiVCyiyRiVNXZo5ot6xlnddUjR46E8aeeeqpsjNWiWc/4rFmzwngeeefG4mzN+6gX/+jRo+FY9pi1tbWF8egcALYlM7tfrA5/9uzZMN6Q/exmtsXM+s1s35DLXjazE2a2J/t4orbTFJG8RvJv/G8BPDbM5b9095XZx9vVnZaIVBtNdnffASD+n0REGl6eN+ieN7OPsn/zy764MrNNZtZlZl0DAwM5bk5E8qg02X8NYAmAlQB6Afy83A+6+2Z3L7l7qZZvRIlIrKJkd/c+d7/l7rcB/AbAmupOS0SqraJkN7OOId9+F8C+cj8rIo2B1tnN7HUA6wC0mVkPgJ8BWGdmKwE4gG4AP6zhHL8S1SbZS4TJkyeH8UuXLoXx+fPnl41du3YtHMv2hr9w4UIY7+/vD+PR3KN5jwSr07N6cXd3d9lYdN4EAHR0dIRxNj7C1tNn183q9KxfPTpu7JhW2gtPk93dNwxz8WsV3ZqIFEany4okQskukgglu0gilOwiiVCyiyRiVLW4Rm2HJ0+eDMeyUgqLR+2UBw8eDMey8hUr3bH23aiMxJZ6ZttBs/IXO+7RksqshZWVx8aNq/zXl41l5S32mLHfp4ZscRWRe4OSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEjKo6e1S7ZDXZvFsTRy2yvb294Vi2HBebG1u2OGpxZUtBs62u2XLPbO7Lli2r+LbZ3Nlt52mBzVtnZ9t455lbpfTMLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiiWioOjvr8Y1qm1988UU4trOzM4w/8MADYbypqalsbO/eveFYtsw1661ubm4O41HNlh1TVquePn16GF+4cGEYj+bO7veUKVPCONsWOTp3gvX5s6XH2WPCtnyudDnoPPTMLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiiWioOjurPUY962x986hODgDLly+vePzx48fDsXPmzAnjrBefbQ88bdq0srEbN26EY9kxj9Z9B3jPeVQrZ9d9/fr1XLcdYfd7xYoVYZytUcDOTygCPVpmtsDM/mpmB8zsYzP7UXb5DDN7x8wOZZ/jFf9FpFAj+dN4E8BP3P1+AP8I4DkzWwbgBQDb3X0pgO3Z9yLSoGiyu3uvu+/Ovr4I4ACAeQDWA9ia/dhWAE/WapIikt/f9aLHzDoBrAKwE0C7u/cCg38QAMwuM2aTmXWZWRd7nSMitTPiZDezKQD+BODH7n5hpOPcfbO7l9y9xBpCRKR2RpTsZjYeg4n+e3f/c3Zxn5l1ZPEOAP21maKIVAMtvdlgjeI1AAfc/RdDQtsAbATwSvb5zbyTybONbd420UWLFoXxaG5sS+WWlpYwzrZNZu2SV69eLRtjrZrRMtTsugFg7ty5YTxqM71wIf4HkZXe2GMe3TYrd86ePeyr0q+wFtlJkyaF8SKMpM7+MIAfANhrZnuyy17CYJL/0cyeAXAMwPdqM0URqQaa7O7+NwDlzkB4tLrTEZFa0emyIolQsoskQskukgglu0gilOwiiWioFlcmqgmzLZdZDb+trS2MR9sLs3owa+Vkc2etnFEbK7tutgQ3a89l1x8dG7btMWtLZucfRHHW4srq5GzLZfaY5zmnpFJ6ZhdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQ0VJ2d1T6nTp1aNvbggw+GY8+fPx/GWc95VGdnfdmsZsuW62K905GZM2eGcbYtMlsSmfWUnzt3rqIYwJfQZucfROcAsDp59LsG8HMA2GMezb1W2znrmV0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRLRUHV2Jqo/sv5htk44i0f9x2xtdlaLjrZcBnjf9tixY8vGLl++HI5lTp8+HcbZuvPRcWO1bnb+AjtHIOrzZ7fN+s2j8y4AXiuvVS09omd2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJxEj2Z18A4HcA5gC4DWCzu//KzF4G8G8A7jRjv+Tub9dqokBcT2Z911euXAnjrBYe1WU7OzvDsWxd+RMnToRx1tcd1ZNZzzfrlWdru7P7Fj0ubM15VsuOfh8AXkuPsOPG9q1n505Ec2M1/kpr9CM5qeYmgJ+4+24zmwrgAzN7J4v90t3/o6JbFpG6Gsn+7L0AerOvL5rZAQDzaj0xEamuv+s1u5l1AlgFYGd20fNm9pGZbTGz1jJjNplZl5l1seWXRKR2RpzsZjYFwJ8A/NjdLwD4NYAlAFZi8Jn/58ONc/fN7l5y99KsWbOqMGURqcSIkt3MxmMw0X/v7n8GAHfvc/db7n4bwG8ArKndNEUkL5rsNvjW32sADrj7L4Zc3jHkx74LYF/1pyci1TKSd+MfBvADAHvNbE922UsANpjZSgAOoBvAD2sywyGicsjs2bPDsaxUwlpcozIPe3nS2jrs2xlfYctYs2WJo1INKwH19fWFcaa5uTmMR63HrNzJttFmZb/o+ln5ipVyGXZcitiyeSTvxv8NwHBHpqY1dRGpLp1BJ5IIJbtIIpTsIolQsoskQskukgglu0giRtVS0lHN9v777w/HLl26NIyzumhUZ3/11VfDsXmXHWai9l1Wy2ZbWbPzD9j157lutjw4a3GNtlVmrb3sfrHlw/PMTVs2i0guSnaRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEmH17Ks1swEAR4dc1AYg3hO4OI06t0adF6C5Vaqac1vk7sMusFDXZP/GjZt1uXupsAkEGnVujTovQHOrVL3mpn/jRRKhZBdJRNHJvrng24806twadV6A5lapusyt0NfsIlI/RT+zi0idKNlFElFIspvZY2b2iZkdNrMXiphDOWbWbWZ7zWyPmXUVPJctZtZvZvuGXDbDzN4xs0PZ53hR+vrO7WUzO5Eduz1m9kRBc1tgZn81swNm9rGZ/Si7vNBjF8yrLset7q/ZzWwsgE8B/DOAHgC7AGxw9/11nUgZZtYNoOTuhZ+AYWbfBnAJwO/c/YHsslcBnHX3V7I/lK3u/tMGmdvLAC4VvY13tltRx9BtxgE8CeBfUeCxC+b1L6jDcSvimX0NgMPu/pm73wDwBwDrC5hHw3P3HQDO3nXxegBbs6+3YvCXpe7KzK0huHuvu+/Ovr4I4M4244Ueu2BedVFEss8DcHzI9z1orP3eHcBfzOwDM9tU9GSG0e7uvcDgLw+AeN+r+qPbeNfTXduMN8yxq2T787yKSPbhFthqpPrfw+6+GsDjAJ7L/l2VkRnRNt71Msw24w2h0u3P8yoi2XsALBjy/XwAJwuYx7Dc/WT2uR/AG2i8raj77uygm33uL3g+X2mkbbyH22YcDXDsitz+vIhk3wVgqZktNrMmAN8HsK2AeXyDmTVnb5zAzJoBfAeNtxX1NgAbs683AnizwLl8TaNs411um3EUfOwK3/7c3ev+AeAJDL4jfwTAvxcxhzLz+gcAH2YfHxc9NwCvY/Dfui8x+B/RMwBmAtgO4FD2eUYDze2/AewF8BEGE6ujoLk9gsGXhh8B2JN9PFH0sQvmVZfjptNlRRKhM+hEEqFkF0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQR/w9daecCm4ulBAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "for i in range(sample.shape[0]):\n", + " plt.imshow(sample[i].view(28, 28), cmap='Greys')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 262, + "metadata": {}, + "outputs": [], + "source": [ + "class PlanarFlow(nn.Module):\n", + " def __init__(self, size=1, init_sigma=0.01):\n", + " super().__init__()\n", + " self.u = nn.Parameter(torch.randn(size, ).normal_(0, 0.01))\n", + " self.w = nn.Parameter(torch.randn(size, ).normal_(0, 0.01))\n", + " self.b = nn.Parameter(torch.zeros(size, 1))\n", + " \n", + " @property\n", + " def normalized_u(self):\n", + " \"\"\"\n", + " Needed for invertibility condition.\n", + " \n", + " See Appendix A.1\n", + " Rezende et al. Variational Inference with Normalizing Flows\n", + " https://arxiv.org/pdf/1505.05770.pdf\n", + " \"\"\"\n", + " def m(x):\n", + " return -1 + torch.log(1 + torch.exp(x))\n", + " wtu = self.w @ self.u.t()\n", + " return self.u + (m(wtu) - wtu) * self.w / (self.w @ self.w.t())\n", + " \n", + " def psi(self, z):\n", + " \"\"\"\n", + " ψ(z) =h′(w^tz+b)w\n", + " \n", + " See eq(11)\n", + " Rezende et al. Variational Inference with Normalizing Flows\n", + " https://arxiv.org/pdf/1505.05770.pdf\n", + " \"\"\"\n", + " return (self.h_prime(z @ self.w.t() + self.b).t() @ self.w.t())\n", + " \n", + " def h(self, x):\n", + " return torch.tanh(x)\n", + " \n", + " def h_prime(self, z):\n", + " return 1 - torch.tanh(z)**2\n", + " \n", + " def forward(self, z):\n", + " if isinstance(z, tuple):\n", + " z, accumulating_log_abs_det = z\n", + " else:\n", + " z, accumulating_log_abs_det = z, 0\n", + " \n", + " psi = self.psi(z)\n", + " u = self.normalized_u\n", + " print(psi.shape, u.shape)\n", + " # determinant of jacobian\n", + " det = (1 + psi @ u.t()).unsqueeze(1)\n", + " \n", + " # log |det Jac|\n", + " log_abs_det = torch.log(torch.abs(det) + 1e-6)\n", + " \n", + " fz = z + u + self.h((z @ self.w.t()).unsqueeze(1) + self.b)\n", + " print(fz.shape)\n", + " \n", + " return fz, log_abs_det + accumulating_log_abs_det\n", + "\n", + "class Planar(nn.Module):\n", + " \"\"\"\n", + " PyTorch implementation of planar flows as presented in \"Variational Inference with Normalizing Flows\"\n", + " by Danilo Jimenez Rezende, Shakir Mohamed. Model assumes amortized flow parameters.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + "\n", + " super(Planar, self).__init__()\n", + "\n", + " self.h = nn.Tanh()\n", + " self.softplus = nn.Softplus()\n", + "\n", + " def der_h(self, x):\n", + " \"\"\" Derivative of tanh \"\"\"\n", + "\n", + " return 1 - self.h(x) ** 2\n", + "\n", + " def forward(self, zk, u, w, b):\n", + " \"\"\"\n", + " Forward pass. Assumes amortized u, w and b. Conditions on diagonals of u and w for invertibility\n", + " will be be satisfied inside this function. Computes the following transformation:\n", + " z' = z + u h( w^T z + b)\n", + " or actually\n", + " z'^T = z^T + h(z^T w + b)u^T\n", + " Assumes the following input shapes:\n", + " shape u = (batch_size, z_size, 1)\n", + " shape w = (batch_size, 1, z_size)\n", + " shape b = (batch_size, 1, 1)\n", + " shape z = (batch_size, z_size).\n", + " \"\"\"\n", + "\n", + " zk = zk.unsqueeze(2)\n", + "\n", + " # reparameterize u such that the flow becomes invertible (see appendix paper)\n", + " uw = torch.bmm(w, u)\n", + " m_uw = -1. + self.softplus(uw)\n", + " w_norm_sq = torch.sum(w ** 2, dim=2, keepdim=True)\n", + " u_hat = u + ((m_uw - uw) * w.transpose(2, 1) / w_norm_sq)\n", + "\n", + " # compute flow with u_hat\n", + " wzb = torch.bmm(w, zk) + b\n", + " z = zk + u_hat * self.h(wzb)\n", + " z = z.squeeze(2)\n", + "\n", + " # compute logdetJ\n", + " psi = w * self.der_h(wzb)\n", + " log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat)))\n", + " log_det_jacobian = log_det_jacobian.squeeze(2).squeeze(1)\n", + "\n", + " return z, log_det_jacobian \n", + "\n", + "class FlowVAE(VAE):\n", + " def __init__(self, latent_size, n_flows=20):\n", + " super().__init__(latent_size)\n", + " \n", + " self.flow = nn.Sequential(*[\n", + " Planar(latent_size) for _ in range(n_flows)\n", + " ]) \n", + " \n", + " def forward(self, x):\n", + " z_theta = self.encoder(x)\n", + " mu0 = z_theta[:, :self.latent_size]\n", + " log_var0 = z_theta[:, self.latent_size:]\n", + " z0 = self.reparameterize(mu0, log_var0)\n", + " zk, accumulating_log_abs_det = self.flow(z0)\n", + " self.decoder(zk)\n", + " \n", + " return z0, zk, self.decoder(zk), accumulating_log_abs_det, mu0, log_var0" + ] + }, + { + "cell_type": "code", + "execution_count": 263, + "metadata": {}, + "outputs": [], + "source": [ + "def log_normal_standard(x, average=False, reduce=True, dim=None):\n", + " log_norm = -0.5 * x * x\n", + "\n", + " if reduce:\n", + " if average:\n", + " return torch.mean(log_norm, dim)\n", + " else:\n", + " return torch.sum(log_norm, dim)\n", + " else:\n", + " return log_norm\n", + " \n", + "def log_normal_diag(x, mean, log_var, average=False, reduce=True, dim=None):\n", + " log_norm = -0.5 * (log_var + (x - mean) * (x - mean) * log_var.exp().reciprocal())\n", + " if reduce:\n", + " if average:\n", + " return torch.mean(log_norm, dim)\n", + " else:\n", + " return torch.sum(log_norm, dim)\n", + " else:\n", + " return log_norm\n", + "\n", + " \n", + "def det_loss(reconstruction_x, x, mu, log_var, z_0, z_k, ldj, beta=1.):\n", + " \"\"\"\n", + " :param z_mu: mean of z_0\n", + " :param z_var: variance of z_0\n", + " :param z_0: first stochastic latent variable\n", + " :param z_k: last stochastic latent variable\n", + " :param ldj: log det jacobian\n", + " \"\"\"\n", + "\n", + " reconstruction_function = nn.BCELoss(reduction='sum')\n", + "\n", + " batch_size = x.size(0)\n", + "\n", + " # - N E_q0 [ ln p(x|z_k) ]\n", + " likelihood = reconstruction_function(reconstruction_x, x)\n", + "\n", + " # ln p(z_k) (not averaged)\n", + " log_p_zk = log_normal_standard(z_k, dim=1)\n", + " # ln q(z_0) (not averaged)\n", + " log_q_z0 = log_normal_diag(z_0, mean=mu, log_var=log_var, dim=1)\n", + " # N E_q0[ ln q(z_0) - ln p(z_k) ]\n", + " summed_logs = torch.sum(log_q_z0 - log_p_zk)\n", + "\n", + " # sum over batches\n", + " summed_ldj = torch.sum(ldj)\n", + "\n", + " # ldj = N E_q_z0[\\sum_k log |det dz_k/dz_k-1| ]\n", + " kl = (summed_logs - summed_ldj)\n", + " loss = likelihood + beta * kl\n", + "\n", + " loss = loss / float(batch_size)\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 264, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([128]) torch.Size([250])\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "invalid argument 2: sizes do not match at /opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THC/generic/THCTensorMathBlas.cu:20", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m11\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cuda'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch, m, device, optimizer)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mz0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreconstruction_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccumulating_log_abs_det\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_var0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m loss = det_loss(x=x, \n", + "\u001b[0;32m/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 546\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 548\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0mlog_var0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mz_theta\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlatent_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mz0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreparameterize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmu0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_var0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mzk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccumulating_log_abs_det\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 546\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 548\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 93\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 546\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 548\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, z)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;31m# determinant of jacobian\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mdet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mpsi\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;31m# log |det Jac|\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: invalid argument 2: sizes do not match at /opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THC/generic/THCTensorMathBlas.cu:20" + ] + } + ], + "source": [ + "latent_size = 250\n", + "m = FlowVAE(latent_size)\n", + "m.cuda()\n", + "\n", + "def train(epoch, m, device, optimizer):\n", + " m.train()\n", + " train_loss = 0\n", + " for batch_idx, (x, _) in enumerate(dl):\n", + " x = x.to(device).view(-1, 784)\n", + " \n", + " optim.zero_grad()\n", + " z0, zk, reconstruction_x, accumulating_log_abs_det, mu0, log_var0 = m(x)\n", + " \n", + " loss = det_loss(x=x, \n", + " reconstruction_x=reconstruction_x, \n", + " mu=mu0,\n", + " log_var=log_var0,\n", + " ldj=accumulating_log_abs_det,\n", + " z_0=z0,\n", + " z_k=zk\n", + " )\n", + " \n", + " recon_batch, mu, logvar = m(x)\n", + " loss = loss_function(recon_batch, x, mu, logvar)\n", + " loss.backward()\n", + " \n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " if batch_idx % 150 == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(x), len(dl.dataset),\n", + " 100. * batch_idx / len(dl),\n", + " loss.item() / len(x)))\n", + "\n", + " print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(dl.dataset)))\n", + "\n", + "\n", + "optim = torch.optim.Adam(m.parameters(), 1e-3)\n", + "\n", + "for epoch in range(1, 11):\n", + " train(epoch, m, 'cuda', optim)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}