diff --git a/physgrad-code.ipynb b/physgrad-code.ipynb index 02716e3..219463a 100644 --- a/physgrad-code.ipynb +++ b/physgrad-code.ipynb @@ -347,7 +347,7 @@ "def loss_function(net, x_gt: CenteredGrid, sip: bool):\n", " y_target = diffuse.fourier(x_gt, 8., 1)\n", " with math.precision(32):\n", - " prediction = field.native_call(net, field.to_float(y_target)).vector[0]\n", + " prediction = field.native_call(net, field.to_float(y_target))\n", " prediction += field.mean(x_gt) - field.mean(prediction)\n", " x = field.stop_gradient(prediction)\n", " if sip:\n",