cleanup of airfoil case
This commit is contained in:
@@ -640,7 +640,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.7.6"
|
"version": "3.8.5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -135,9 +135,9 @@
|
|||||||
" c = np.concatenate(c,axis=1)\n",
|
" c = np.concatenate(c,axis=1)\n",
|
||||||
" display(Image.fromarray( cm.magma(c, bytes=True) ))\n",
|
" display(Image.fromarray( cm.magma(c, bytes=True) ))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"num=72\n",
|
"NUM=72\n",
|
||||||
"print(\"\\nHere are all 3 inputs are shown at the top (mask,in x, in y) \\nSide by side with the 3 output channels (p,vx,vy) at the bottom:\")\n",
|
"print(\"\\nHere are all 3 inputs are shown at the top (mask,in x, in y) \\nSide by side with the 3 output channels (p,vx,vy) at the bottom:\")\n",
|
||||||
"showSbs(npfile[\"inputs\"][num],npfile[\"targets\"][num])\n"
|
"showSbs(npfile[\"inputs\"][NUM],npfile[\"targets\"][NUM])\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -173,14 +173,14 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# some global training parameters\n",
|
"# some global training constants\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# number of training epochs\n",
|
"# number of training epochs\n",
|
||||||
"epochs = 100\n",
|
"EPOCHS = 100\n",
|
||||||
"# batch size\n",
|
"# batch size\n",
|
||||||
"batch_size = 10\n",
|
"BATCH_SIZE = 10\n",
|
||||||
"# learning rate\n",
|
"# learning rate\n",
|
||||||
"lr = 0.00002\n",
|
"LR = 0.00002\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class DfpDataset():\n",
|
"class DfpDataset():\n",
|
||||||
" def __init__(self, inputs,targets): \n",
|
" def __init__(self, inputs,targets): \n",
|
||||||
@@ -196,8 +196,8 @@
|
|||||||
"tdata = DfpDataset(npfile[\"inputs\"],npfile[\"targets\"])\n",
|
"tdata = DfpDataset(npfile[\"inputs\"],npfile[\"targets\"])\n",
|
||||||
"vdata = DfpDataset(npfile[\"vinputs\"],npfile[\"vtargets\"])\n",
|
"vdata = DfpDataset(npfile[\"vinputs\"],npfile[\"vtargets\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"trainLoader = torch.utils.data.DataLoader(tdata, batch_size=batch_size, shuffle=True , drop_last=True) \n",
|
"trainLoader = torch.utils.data.DataLoader(tdata, batch_size=BATCH_SIZE, shuffle=True , drop_last=True) \n",
|
||||||
"valiLoader = torch.utils.data.DataLoader(vdata, batch_size=batch_size, shuffle=False, drop_last=True) \n",
|
"valiLoader = torch.utils.data.DataLoader(vdata, batch_size=BATCH_SIZE, shuffle=False, drop_last=True) \n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Training & validation batches: {} , {}\".format(len(trainLoader),len(valiLoader) ))"
|
"print(\"Training & validation batches: {} , {}\".format(len(trainLoader),len(valiLoader) ))"
|
||||||
]
|
]
|
||||||
@@ -326,9 +326,9 @@
|
|||||||
"id": "QAl3VgKVQSI3"
|
"id": "QAl3VgKVQSI3"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Initialize net...\n",
|
"Next, we can initialize an instance of the `DfpNet`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The `expo` parameter here controls the exponent for the feature maps of our Unet: this directly scales the network size (3 gives a model with ca. 150k parameters). This is relatively small for a generative model for $3 \\times 128^2 = \\text{ca. }49k$ outputs, but yields fast training times and prevents overfitting given the relatively small data set we're using here. Hence it's a good starting point."
|
"Here, the `EXPO` parameter here controls the exponent for the feature maps of our Unet: this directly scales the network size (3 gives a model with ca. 150k parameters). This is relatively small for a generative model for $3 \\times 128^2 = \\text{ca. }49k$ outputs, but yields fast training times and prevents overfitting given the relatively small data set we're using here. Hence it's a good starting point."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -426,10 +426,10 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# channel exponent to control network size\n",
|
"# channel exponent to control network size\n",
|
||||||
"expo = 3\n",
|
"EXPO = 3\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# setup network\n",
|
"# setup network\n",
|
||||||
"net = DfpNet(channelExponent=expo)\n",
|
"net = DfpNet(channelExponent=EXPO)\n",
|
||||||
"#print(net) # to double check the details...\n",
|
"#print(net) # to double check the details...\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_parameters = filter(lambda p: p.requires_grad, net.parameters())\n",
|
"model_parameters = filter(lambda p: p.requires_grad, net.parameters())\n",
|
||||||
@@ -441,10 +441,10 @@
|
|||||||
"net.apply(weights_init)\n",
|
"net.apply(weights_init)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"criterionL1 = nn.L1Loss()\n",
|
"criterionL1 = nn.L1Loss()\n",
|
||||||
"optimizerG = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0)\n",
|
"optimizerG = optim.Adam(net.parameters(), lr=LR, betas=(0.5, 0.999), weight_decay=0.0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"targets = torch.autograd.Variable(torch.FloatTensor(batch_size, 3, 128, 128))\n",
|
"targets = torch.autograd.Variable(torch.FloatTensor(BATCH_SIZE, 3, 128, 128))\n",
|
||||||
"inputs = torch.autograd.Variable(torch.FloatTensor(batch_size, 3, 128, 128))\n"
|
"inputs = torch.autograd.Variable(torch.FloatTensor(BATCH_SIZE, 3, 128, 128))\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -455,7 +455,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"## Training\n",
|
"## Training\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Finally, we can train"
|
"Finally, we can train the model. This step can take a while, as we'll go over all 320 samples 100 times, and continually evaluate the validation samples to keep track of how well we're doing."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -584,11 +584,11 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"if os.path.isfile(\"model\"):\n",
|
"if os.path.isfile(\"model\"):\n",
|
||||||
" print(\"Found existing model, loading & skipping training\")\n",
|
" print(\"Found existing model, loading & skipping training\")\n",
|
||||||
" net.load_state_dict(torch.load(doLoad)) # optionally, load existing model\n",
|
" net.load_state_dict(torch.load(\"model\")) # optionally, load existing model\n",
|
||||||
"\n",
|
"\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" print(\"Training from scratch\")\n",
|
" print(\"Training from scratch\")\n",
|
||||||
" for epoch in range(epochs):\n",
|
" for epoch in range(EPOCHS):\n",
|
||||||
" net.train()\n",
|
" net.train()\n",
|
||||||
" L1_accum = 0.0\n",
|
" L1_accum = 0.0\n",
|
||||||
" for i, traindata in enumerate(trainLoader, 0):\n",
|
" for i, traindata in enumerate(trainLoader, 0):\n",
|
||||||
@@ -622,7 +622,7 @@
|
|||||||
" history_L1.append( L1_accum / len(trainLoader) )\n",
|
" history_L1.append( L1_accum / len(trainLoader) )\n",
|
||||||
" history_L1val.append( L1val_accum / len(valiLoader) )\n",
|
" history_L1val.append( L1val_accum / len(valiLoader) )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if i<3 or i%20==0:\n",
|
" if epoch<3 or epoch%20==0:\n",
|
||||||
" print( \"Epoch: {}, L1 train: {:7.5f}, L1 vali: {:7.5f}\".format(epoch, history_L1[-1], history_L1val[-1]) )\n",
|
" print( \"Epoch: {}, L1 train: {:7.5f}, L1 vali: {:7.5f}\".format(epoch, history_L1[-1], history_L1val[-1]) )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" torch.save(net.state_dict(), \"model\" )\n",
|
" torch.save(net.state_dict(), \"model\" )\n",
|
||||||
@@ -635,11 +635,9 @@
|
|||||||
"id": "4KuUpJsSL3Jv"
|
"id": "4KuUpJsSL3Jv"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Yay, the model is trained...! \n",
|
"Yay, the model is trained...! The losses nicely went down in terms of absolute values: With the standard settings from an initial value of around 0.2 for the validation loss, to ca. 0.02 after 100 epochs. \n",
|
||||||
"\n",
|
"\n",
|
||||||
"The losses nicely went down in terms of absolute values, let's look at the graphs.\n",
|
"Let's look at the graphs to get some intuition for how the trained progressed over time. This is typically important to identify longer-term trends in the training. In practice it's tricky to spot whether the overall trend of 100 or so noisy numbers in a command line log is going slightly up or down - this is much easier to spot in a visualization."
|
||||||
"\n",
|
|
||||||
"This is typically important to identify longer-term trends in the data. In practice it's tricky to spot whether the overall trend of 100 or so noisy numbers is going slightly up or down."
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -912,11 +910,11 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"## Nex steps\n",
|
"## Nex steps\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* Experiment with learning rate, dropout, and model size to reduce the error on the test set. How low can you get it with the given training data?\n",
|
"* Experiment with learning rate, dropout, and model size to reduce the error on the test set. How small can you make it with the given training data?\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* As you'll see, it's a bit limited here what you can get out of this dataset, head over to [the main github repo of this project](https://github.com/thunil/Deep-Flow-Prediction) to download larger data sets, or generate own data\n",
|
"* As you'll see, it's a bit limited here what you can get out of this dataset, head over to [the main github repo of this project](https://github.com/thunil/Deep-Flow-Prediction) to download larger data sets, or generate own data\n",
|
||||||
"\n",
|
"\n",
|
||||||
"**TODO us: provide data with \"errors\" (nan & huge neg number in 1 cell), filter out to make model train...**\n"
|
"**(TODO, us: for exercise, provide data with \"errors\" (nan & huge neg number in 1 cell), filter out to make model train...)**\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -947,7 +945,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.7.6"
|
"version": "3.8.5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Reference in New Issue
Block a user