fixed paths in control code

This commit is contained in:
NT
2021-04-09 14:56:26 +08:00
parent af5d7bd48a
commit 93a3ca9ed8

View File

@@ -163,7 +163,7 @@
"pretrain_data_path = 'moving-squares'\n", "pretrain_data_path = 'moving-squares'\n",
"shape_library = load_shapes('PDE-Control/notebooks/shapes')" "shape_library = load_shapes('PDE-Control/notebooks/shapes')"
], ],
"execution_count": 2, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -194,7 +194,7 @@
" pylab.subplot(1, len(shape_library), t+1)\n", " pylab.subplot(1, len(shape_library), t+1)\n",
" pylab.imshow(shape_library[t], origin='lower')\n" " pylab.imshow(shape_library[t], origin='lower')\n"
], ],
"execution_count": 3, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "display_data", "output_type": "display_data",
@@ -241,7 +241,7 @@
" [scene.write_sim_frame([start], ['density'], frame=f) for f in range(step_count)]\n", " [scene.write_sim_frame([start], ['density'], frame=f) for f in range(step_count)]\n",
" scene.write_sim_frame([end__], ['density'], frame=step_count)" " scene.write_sim_frame([end__], ['density'], frame=step_count)"
], ],
"execution_count": 4, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@@ -296,7 +296,7 @@
" density = AABox(lower=pos-size//2, upper=pos-size//2+size).value_at(domain.center_points())\n", " density = AABox(lower=pos-size//2, upper=pos-size//2+size).value_at(domain.center_points())\n",
" scene.write_sim_frame([density], ['density'], frame=frame)" " scene.write_sim_frame([density], ['density'], frame=frame)"
], ],
"execution_count": 5, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@@ -335,7 +335,7 @@
"val_range = range(100, 200)\n", "val_range = range(100, 200)\n",
"train_range = range(200, 1000)" "train_range = range(200, 1000)"
], ],
"execution_count": 6, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -354,7 +354,7 @@
"Consequently, no sequence needs to be simulated (`sequence_class=None`) and an observation loss is required at frame $\\frac n 2$ (`obs_loss_frames=[n // 2]`).\n", "Consequently, no sequence needs to be simulated (`sequence_class=None`) and an observation loss is required at frame $\\frac n 2$ (`obs_loss_frames=[n // 2]`).\n",
"The pretrained network checkpoints are stored in `supervised_checkpoints`.\n", "The pretrained network checkpoints are stored in `supervised_checkpoints`.\n",
"\n", "\n",
"*Note: The next cell will run for some time. If you have a set of pretrained networks, you can skip it and load the pretrained networks instead (see instructions below).*" "*Note: The next cell will run for some time. The PDE-Control git repo comes with a set of pre-trained networks. So if you want to focus on the evaluation, you can skip the training and load the pretrained networks instead by commenting out the training cells, and uncommenting the cells for loading below.*"
] ]
}, },
{ {
@@ -387,9 +387,9 @@
"outputId": "994c7fdc-a5aa-4769-eab8-1f27f29ad082" "outputId": "994c7fdc-a5aa-4769-eab8-1f27f29ad082"
}, },
"source": [ "source": [
"supervised_checkpoints" "supervised_checkpoints # this is where the checkpoints end up when re-training:"
], ],
"execution_count": 8, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "execute_result", "output_type": "execute_result",
@@ -414,9 +414,9 @@
"id": "jD7nKXCv30dl" "id": "jD7nKXCv30dl"
}, },
"source": [ "source": [
"# supervised_checkpoints = {'OP%d' % n: '../networks/shapes/supervised/OP%d_1000' % n for n in [2, 4, 8, 16]}" "# supervised_checkpoints = {'OP%d' % n: 'PDE-Control/networks/shapes/supervised/OP%d_1000' % n for n in [2, 4, 8, 16]}"
], ],
"execution_count": 9, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -461,9 +461,9 @@
"id": "-KOcgr5M30dn" "id": "-KOcgr5M30dn"
}, },
"source": [ "source": [
"# supervised_checkpoints['CFE'] = '../networks/shapes/CFE/CFE_2000'" "# supervised_checkpoints['CFE'] = 'PDE-Control/networks/shapes/CFE/CFE_2000'"
], ],
"execution_count": 11, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -521,7 +521,7 @@
" obs_loss_frames=[step_count], trainable_networks=['CFE', 'OP2', 'OP4', 'OP8', 'OP16'],\n", " obs_loss_frames=[step_count], trainable_networks=['CFE', 'OP2', 'OP4', 'OP8', 'OP16'],\n",
" sequence_class=StaggeredSequence, learning_rate=5e-4).prepare()" " sequence_class=StaggeredSequence, learning_rate=5e-4).prepare()"
], ],
"execution_count": 12, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@@ -608,7 +608,7 @@
"source": [ "source": [
"The next cell initializes the networks using the supervised checkpoints and then trains all networks jointly. You can increase the number of optimization steps or execute the next cell multiple times to further increase performance.\n", "The next cell initializes the networks using the supervised checkpoints and then trains all networks jointly. You can increase the number of optimization steps or execute the next cell multiple times to further increase performance.\n",
"\n", "\n",
"*Note: The next cell will run for some time. Optionally, you can skip this cell and load a pretrained networks instead.*" "*Note: The next cell will run for some time. Optionally, you can skip this cell and load a pretrained networks instead with code in the cell below.*"
] ]
}, },
{ {
@@ -619,7 +619,7 @@
"source": [ "source": [
"staggered_app.load_checkpoints(supervised_checkpoints)\n", "staggered_app.load_checkpoints(supervised_checkpoints)\n",
"for i in range(1000):\n", "for i in range(1000):\n",
" staggered_app.progress() # Run staggered Optimization for one batch\n", " staggered_app.progress() # run staggered Optimization for one batch\n",
"staggered_checkpoint = staggered_app.save_model()" "staggered_checkpoint = staggered_app.save_model()"
], ],
"execution_count": null, "execution_count": null,
@@ -631,10 +631,10 @@
"id": "xpLvDj5-30dq" "id": "xpLvDj5-30dq"
}, },
"source": [ "source": [
"# staggered_checkpoint = {net: '../networks/shapes/staggered/all_53750' for net in ['CFE', 'OP2', 'OP4', 'OP8', 'OP16']}\n", "# staggered_checkpoint = {net: 'PDE-Control/networks/shapes/staggered/all_53750' for net in ['CFE', 'OP2', 'OP4', 'OP8', 'OP16']}\n",
"# staggered_app.load_checkpoints(staggered_checkpoint)" "# staggered_app.load_checkpoints(staggered_checkpoint)"
], ],
"execution_count": 14, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -657,7 +657,7 @@
"source": [ "source": [
"states = staggered_app.infer_all_frames(test_range)" "states = staggered_app.infer_all_frames(test_range)"
], ],
"execution_count": 15, "execution_count": null,
"outputs": [] "outputs": []
}, },
{ {
@@ -699,7 +699,7 @@
" pylab.title('target')\n", " pylab.title('target')\n",
" pylab.imshow(testset[1][i,...,0], origin='lower')\n" " pylab.imshow(testset[1][i,...,0], origin='lower')\n"
], ],
"execution_count": 16, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "display_data", "output_type": "display_data",
@@ -746,7 +746,7 @@
" errors.append( solution/initial )\n", " errors.append( solution/initial )\n",
"print(\"Relative MAE: \"+format(np.mean(errors)))" "print(\"Relative MAE: \"+format(np.mean(errors)))"
], ],
"execution_count": 41, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",