diff --git a/physgrad-comparison.ipynb b/physgrad-comparison.ipynb index 15f580c..2f9898c 100644 --- a/physgrad-comparison.ipynb +++ b/physgrad-comparison.ipynb @@ -44,7 +44,7 @@ "\n", "## 3 Spaces\n", "\n", - "In order to understand the following examples, it's important to keep in mind that we'll deal with mappings between the three _spaces_ we've introduced here:\n", + "In order to understand the following examples, it's important to keep in mind that we're dealing with mappings between the three _spaces_ we've introduced here:\n", "$\\mathbf{x}$, $\\mathbf{z}$ and $L$. A regular forward pass maps an\n", "$\\mathbf{x}$ to $L$, while for the optimization we'll need to associate values\n", "and changes in $L$ with positions in $\\mathbf{x}$. While doing this, it will \n", @@ -68,24 +68,33 @@ "## Implementation\n", "\n", "For this example we'll use the [JAX framework](https://github.com/google/jax), which represents a nice alternative to pytorch and tensorflow for efficiently working with differentiable functions.\n", - "\n", "JAX also has a nice numpy wrapper that implements most of numpy's functions. Below we'll use this wrapper as `np`, and the _original_ numpy as `onp`.\n", - "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as np\n", + "import numpy as onp\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "We'll start by defining the $\\mathbf{z}$ and $L$ functions, together with a single composite function `fun` which calls L and z. Having a single native python function is necessary for many of the JAX operations." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 17, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -103,17 +112,12 @@ " DeviceArray(90., dtype=float32))" ] }, - "execution_count": 1, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import jax\n", - "import jax.numpy as np\n", - "import numpy as onp\n", - "\n", - "\n", "# \"physics\" function z\n", "def fun_z(x):\n", " return np.array( [x[0], x[1]*x[1]] )\n", @@ -139,7 +143,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we can evaluate the derivatives of our function via `jax.grad`. E.g., `jax.grad(fun_L)(fun_z(x))` evaluates the Jacobian $\\partial L / \\partial z$. The cell below evaluates this and a few variants, together with a sanity check for the inverse of the Jacobian of $\\mathbf{z}$:" + "Now we can evaluate the derivatives of our function via `jax.grad`. E.g., `jax.grad(fun_L)(fun_z(x))` evaluates the Jacobian $\\partial L / \\partial \\mathbf{z}$. The cell below evaluates this and a few variants, together with a sanity check for the inverse of the Jacobian of $\\mathbf{z}$:" ] }, { @@ -190,7 +194,13 @@ "source": [ "The last line is worth a closer look: here we print the gradient $\\partial L / \\partial \\mathbf{x}$ at our initial position. And while we know that we should just move diagonally towards the origin (with the zero vector being the minimizer), this gradient is not very diagonal - it has a strongly dominant component along $x_1$ with an entry of 108.\n", "\n", - "Let's see how the different methods cope with this situation. 
We'll compare the first order method _gradient descent_ (i.e., regular, non-stochastic, \"steepest gradient descent\"), _Newton's method_ as a representative of the second order methods, and _physical gradients_.\n"
+    "Let's see how the different methods cope with this situation. We'll compare\n",
+    "\n",
+    "* the first order method _gradient descent_ (i.e., regular, non-stochastic, \"steepest gradient descent\"),\n",
+    "\n",
+    "* _Newton's method_ as a representative of the second order methods,\n",
+    "\n",
+    "* and _physical gradients_.\n"
   ]
  },
  {
@@ -326,9 +336,9 @@
    "\n",
    "Hence, in addition to the same gradient as for GD, we now need to evaluate and invert the Hessian of $\frac{\partial^2 L }{ \partial \mathbf{x}^2 }$.\n",
    "\n",
-    "This is quite straightforward in JAX: we can call `jax.jacobian` two times, and then use the JAX version of `linalg.inv` to invert the matrix.\n",
+    "This is quite straightforward in JAX: we can call `jax.jacobian` twice, and then use the JAX version of `linalg.inv` to invert the resulting matrix.\n",
    "\n",
-    "For the optimization with Newton's method we'll use a larger step size of $\eta =1/3$. For this example and the following one, we've chosen the stepsize such that the magnitude of the first update step is roughly the same as the one of GD. In this way, we can compare the trajectories of all three methods relative to each other. Note that this is by no means meant to illustrate or compare the stability of the methods in this example. Stability and upper limits for $\eta$ are separate topics. Here we're focusing on convergence properties.\n",
+    "For the optimization with Newton's method we'll use a larger step size of $\eta = 1/3$. For this example and the following one, we've chosen the step size such that the magnitude of the first update step is roughly the same as that of GD. In this way, we can compare the trajectories of all three methods relative to each other. Note that this is by no means meant to illustrate or compare the stability of the methods here. Stability and upper limits for $\eta$ are separate topics. Here we're focusing on convergence properties.\n",
    "\n",
    "In the next cell, we apply the Newton updates ten times starting from the same initial guess:"
   ]
  },
@@ -450,7 +460,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
@@ -509,7 +519,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'x1')"
      ]
     },
-     "execution_count": 9,
+     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
@@ -559,7 +569,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
@@ -599,16 +609,16 @@
    "\n",
    "## Z Space\n",
    "\n",
-    "To understand the behavior and differences of the methods here, it's important to keep in mind that we're not dealing with a black box that maps between $\mathbf{x}$ and $L$, but rather there are spaces inbetween that matter. In our case, we only have a single $\mathbf{z}$ space, but for DL settings, we might have a large number of latent spaces, over which we have a certain amount of control.\n",
+    "To understand the behavior and differences of the methods here, it's important to keep in mind that we're not dealing with a black box that maps between $\mathbf{x}$ and $L$, but rather there are spaces in between that matter. 
In our case, we only have a single $\mathbf{z}$ space, but for DL settings, we might have a large number of latent spaces, over which we have a certain amount of control. We will return to NNs soon, but for now let's focus on $\mathbf{z}$.\n",
    "\n",
-    "We will return to NNs soon, but for now let's focus on $\mathbf{z}$. One first thing to note is that for PG, we explicitly map from $L$ to $\mathbf{z}$, and then continue with a mapping to $\mathbf{x}$. Thus we already obtained the trajectory in $\mathbf{z}$ space, and not conincidentally, we already stored it in the `historyPGz` list above.\n",
+    "A first thing to note is that for PG, we explicitly map from $L$ to $\mathbf{z}$, and then continue with a mapping to $\mathbf{x}$. Thus we already obtained the trajectory in $\mathbf{z}$ space, and not coincidentally, we already stored it in the `historyPGz` list above.\n",
    "\n",
    "Let's directly take a look what PG did in $\mathbf{z}$ space:"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'z1')"
      ]
     },
-     "execution_count": 14,
+     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
@@ -649,22 +659,87 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "For PG we're making explicit steps in $\mathbf{z}$ space, which progress in a straight diagonal line to the origin. (Note that in $\mathbf{z}$ space the origin is likewise the solution.)\n",
+    "For PG we're making explicit steps in $\mathbf{z}$ space, which progress in a straight diagonal line to the origin (which is likewise the solution in $\mathbf{z}$ space).\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Interestingly, neither GD nor Newton's method gives us information about progress in intermediate spaces (like the $\mathbf{z}$ space). \n",
    "\n",
-    "Interestingly, neither GD nor Newton's method give us information about progress in intermediate spaces like $\mathbf{z}$: for GD we're concatenating the Jacobians, so we're moving in directions that locally should decrease the loss. However, the $\mathbf{z}$ position is influenced by $\mathbf{x}$, and hence we don't know where we end up in $\mathbf{z}$ space until we have the definite point in $\mathbf{x}$ space.\n",
+    "For **GD** we're concatenating the Jacobians, so we're moving in directions that locally should decrease the loss. However, the $\mathbf{z}$ position is influenced by $\mathbf{x}$, and hence we don't know where we end up in $\mathbf{z}$ space until we have the definite point in $\mathbf{x}$ space. (For NNs in general we won't know at which latent-space points we end up after a GD update until we've actually computed all updated weights.)\n",
    "\n",
-    "With PGs we do not have this problem, as we directly map points in $\mathbf{z}$ to $\mathbf{x}$ via an inverse function. Hence we know eactly where we started in $\mathbf{z}$ space, as this position is crucial for the inversion."
+    "More specifically, we have an update $-\eta \frac{\partial L}{\partial \mathbf{x}}$ for GD, which means we arrive at $\mathbf{z}(\mathbf{x} - \eta \frac{\partial L}{\partial \mathbf{x}})$ in $\mathbf{z}$ space. 
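Here the chain rule gives $\frac{\partial L}{\partial \mathbf{x}} = \frac{\partial L}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial \mathbf{x}}$, which is where the squared Jacobian term below comes from. 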
A Taylor expansion with \n",
+    "$h = \eta \frac{\partial L}{\partial \mathbf{x}}$ yields \n",
+    "\n",
+    "$\n",
+    "\quad\n",
+    "\mathbf{z}(\mathbf{x} - h) = \n",
+    "\mathbf{z}(\mathbf{x}) - h \frac{\partial \mathbf{z}}{\partial \mathbf{x}} + \mathcal{O}( h^2 )\n",
+    "= \mathbf{z}(\mathbf{x}) - \eta \frac{\partial L}{\partial \mathbf{z}} (\frac{\partial \mathbf{z}}{\partial \mathbf{x}})^2 + \mathcal{O}( h^2 )\n",
+    "$.\n",
+    "\n",
+    "And $\frac{\partial L}{\partial \mathbf{z}} (\frac{\partial \mathbf{z}}{\partial \mathbf{x}})^2$ clearly differs from the step $\frac{\partial L}{\partial \mathbf{z}}$ we would compute during the back-propagation pass in GD for $\mathbf{z}$.\n",
+    "\n",
+    "**Newton's method** does not fare much better: we compute first-order derivatives like for GD, and second-order derivatives for the Hessian of the full process. But since both are approximations, the actual intermediate states resulting from an update step are unknown until the full chain is evaluated. In the _Consistency in function compositions_ paragraph for Newton's method in {doc}`physgrad`, the squared $\frac{\partial \mathbf{z}}{\partial \mathbf{x}}$ term for the Hessian already indicated this dependency.\n",
+    "\n",
+    "With **PGs** we do not have this problem: PGs can directly map points in $\mathbf{z}$ to $\mathbf{x}$ via the inverse function. Hence we know exactly where we started in $\mathbf{z}$ space, as this position is crucial for evaluating the inverse.\n",
+    "\n",
+    "In the simple setting of this section, we only have a single latent space, and we already stored all values in $\mathbf{x}$ space during the optimization (in the `history` lists). Hence, now we can go back and re-evaluate `fun_z` to obtain the positions in $\mathbf{z}$ space."
   ]
  },
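+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before re-evaluating the full trajectories, here's a quick numerical sanity check of the Taylor argument above. The next cell is a small sketch added for illustration: only `fun`, `fun_z` and `jax` come from the cells above, while `eta_test` and the helper variables are introduced here. It compares where a single GD step actually lands in $\mathbf{z}$ space with the first-order estimate $\mathbf{z}(\mathbf{x}) - \eta \frac{\partial \mathbf{z}}{\partial \mathbf{x}} \frac{\partial L}{\partial \mathbf{x}}$; for a small $\eta$ the two should closely agree."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sanity check (added sketch): one GD step in z space vs. first-order Taylor estimate\n",
+    "eta_test = 0.01   # illustrative small step size\n",
+    "x0 = np.asarray([3.,3.])\n",
+    "g  = jax.grad(fun)(x0)        # dL/dx at x0\n",
+    "Jz = jax.jacobian(fun_z)(x0)  # dz/dx at x0\n",
+    "print('actual GD position in z:', fun_z(x0 - eta_test*g))\n",
+    "print('first-order estimate:   ', fun_z(x0) - eta_test * Jz @ g)\n"
+   ]
+  },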
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
-    "historyGDz = onp.asarray(historyGDz)\n",
-    "historyNtz = onp.asarray(historyNtz)\n",
    "\n",
+    "x = np.asarray([3.,3.])\n",
+    "eta = 0.01\n",
+    "historyGDz = []\n",
+    "historyNtz = []\n",
+    "\n",
+    "for i in range(1,10):\n",
+    "    historyGDz.append(fun_z(historyGD[i]))\n",
+    "    historyNtz.append(fun_z(historyNt[i]))\n",
+    "\n",
+    "historyGDz = onp.asarray(historyGDz)\n",
+    "historyNtz = onp.asarray(historyNtz)\n"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0, 0.5, 'z1')"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAA... [base64-encoded PNG plot data omitted] ...AAAABJRU5ErkJggg==\n",
+      "text/plain": [
+       "<Figure size 400x400 with 1 Axes>
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "axes.set_title('z space')\n", "axes.scatter(historyGDz[:,0], historyGDz[:,1], lw=0.5, marker='*', color='blue')\n", @@ -675,19 +750,28 @@ "axes.set_ylabel('z1')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These trajectories confirm the intuition outlined in the previous sections: GD in blue gives a very sub-optimal trajectory in $\\mathbf{z}$. Newton (in orange) does better, but is still clearly curved, in contrast to the straight, and diagonal red trajectory for the PG-based optimization.\n", + "\n", + "The behavior in intermediate spaces becomes especially important when they're not only abstract latent spaces as in this example, but when they have actual physical meanings." + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusions \n", "\n", - "That concludes our simple example. And despite its simplicity, it already surprisingly large differences between gradient descent, Newton's method, and the physical gradients emerged.\n", + "That concludes our simple example. Despite its simplicity, it already showed surprisingly large differences between gradient descent, Newton's method, and the physical gradients.\n", "\n", - "The main takeaways were:\n", + "The main takeaways of this section are:\n", "* GD easily yields \"unbalanced\" updates\n", - "* Newtons method does better, but \n", + "* Newtons method does better, but is far from optimal\n", "* PGs outperform both if an inverse function is available\n", - "* Be aware of how an optimizer progresses in latent spaces\n", + "* The choice of optimizer strongly affects progress in latent spaces\n", " \n", "In the next sections we can build on these observations to use PGs for training NNs via invertible physical models." ] @@ -708,9 +792,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BFGS optimization test run, find x such that y=[2,2]:\n" + ] + }, + { + "data": { + "text/plain": [ + "array([2.00000003, 1.41421353])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "def fun_z_inv_opt(target_y, x_ini):\n", " # a bit ugly, we switch to pure scipy here inside each iteration for BFGS\n", @@ -741,9 +843,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PG iter 0: [2.09999967 2.50998022]\n", + "PG iter 1: [1.46999859 2.10000011]\n", + "PG iter 2: [1.02899871 1.75698602]\n", + "PG iter 3: [0.72029824 1.4699998 ]\n", + "PG iter 4: [0.50420733 1.22988982]\n", + "PG iter 5: [0.35294448 1.02899957]\n", + "PG iter 6: [0.24705997 0.86092355]\n", + "PG iter 7: [0.17294205 0.72030026]\n", + "PG iter 8: [0.12106103 0.60264817]\n", + "PG iter 9: [0.08474171 0.50421247]\n" + ] + } + ], "source": [ "x = np.asarray([3.,3.])\n", "eta = 0.3\n", @@ -766,14 +885,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Nice! It works, just like PG. Not much point plotting this, it's basiclly the PG version, but let's measure the difference..." + "Nice! It works, just like PG. Not much point plotting this, it's basiclly the PG version, but let's measure the difference. 
Below, we compute the MAE, which for this simple example turns out to be on the order of our floating point accuracy."
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
   "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MAE difference between analytic PG and approximate inversion: 0.000001\n"
+     ]
+    }
+   ],
   "source": [
    "historyPGa = onp.asarray(history)\n",
    "updatesPGa = onp.asarray(updates) \n",
@@ -793,9 +920,9 @@
    "\n",
    "Based on this code example you can try the following modifications:\n",
    "\n",
-    "- instead of the simple L(z(x)) contrsuction above, try other, more complicated functions\n",
+    "- Instead of the simple L(z(x)) function above, try other, more complicated functions.\n",
    "\n",
-    "- instead of the simple \"regular\" gradient descent, compare the versions above to commonly used DL optimizers such as AdaGrad, RmsProp or Adam."
+    "- Replace the simple \"regular\" gradient descent with another optimizer, e.g., commonly used DL optimizers such as AdaGrad, RMSprop or Adam. Compare the versions above with the new trajectories."
   ]
  },
  {