Fixed device bug in teaser example
This commit is contained in:
parent
d95c94ac58
commit
8eb2c3c7f7
@ -553,7 +553,7 @@
|
|||||||
" dt = 1./n_steps\n",
|
" dt = 1./n_steps\n",
|
||||||
" x_in = x0\n",
|
" x_in = x0\n",
|
||||||
" for i in range(n_steps):\n",
|
" for i in range(n_steps):\n",
|
||||||
" x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1) )\n",
|
" x0 = x0 + dt * nn(x0, torch.tensor([i/n_steps]).expand(x0.shape[0], 1).to(x0.device) )\n",
|
||||||
" x0[:,0] = x_in[:,0] # condition on original x position\n",
|
" x0[:,0] = x_in[:,0] # condition on original x position\n",
|
||||||
" trajectory.append(x0)\n",
|
" trajectory.append(x0)\n",
|
||||||
" return trajectory, t\n",
|
" return trajectory, t\n",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user