Fixed device bug in teaser example

This commit is contained in:
N_T 2025-03-31 16:23:15 +02:00
parent d95c94ac58
commit 8eb2c3c7f7

View File

@ -553,7 +553,7 @@
" dt = 1./n_steps\n", " dt = 1./n_steps\n",
" x_in = x0\n", " x_in = x0\n",
" for i in range(n_steps):\n", " for i in range(n_steps):\n",
" x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1) )\n", " x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1).to(x0.device) )\n",
" x0[:,0] = x_in[:,0] # condition on original x position\n", " x0[:,0] = x_in[:,0] # condition on original x position\n",
" trajectory.append(x0)\n", " trajectory.append(x0)\n",
" return trajectory, t\n", " return trajectory, t\n",