diff --git a/intro-teaser.ipynb b/intro-teaser.ipynb index e80d1ab..d6e4184 100644 --- a/intro-teaser.ipynb +++ b/intro-teaser.ipynb @@ -553,7 +553,7 @@ " dt = 1./n_steps\n", " x_in = x0\n", " for i in range(n_steps):\n", - " x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1) )\n", + " x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1).to(x0.device) )\n", " x0[:,0] = x_in[:,0] # condition on original x position\n", " trajectory.append(x0)\n", " return trajectory, t\n",