updated data link, increased BNN expo

This commit is contained in:
NT
2021-07-31 19:54:08 +02:00
parent 33ba170d77
commit 8ee86cf058
2 changed files with 109 additions and 94 deletions

File diff suppressed because one or more lines are too long

View File

@@ -95,25 +95,26 @@
"print(\"Torch version {}\".format(torch.__version__))\n",
"\n",
"# get training data\n",
"dir = \"./\"\n",
"if True:\n",
" # download\n",
" if not os.path.isfile('data-airfoils.npz'):\n",
" import urllib.request\n",
" url=\"https://ge.in.tum.de/download/2019-deepFlowPred/data.npz\"\n",
" import requests\n",
" print(\"Downloading training data (300MB), this can take a few minutes the first time...\")\n",
" urllib.request.urlretrieve(url, 'data-airfoils.npz')\n",
" npfile=np.load('data-airfoils.npz')\n",
"\n",
"else:\n",
" with open(\"data-airfoils.npz\", 'wb') as datafile:\n",
" resp = requests.get('https://dataserv.ub.tum.de/s/m1615239/download?path=%2F&files=dfp-data-400.npz', verify=False)\n",
" datafile.write(resp.content)\n",
"else: \n",
" # alternative: load from google drive (upload there beforehand):\n",
" from google.colab import drive\n",
" drive.mount('/content/gdrive')\n",
" npfile=np.load('gdrive/My Drive/data-airfoils.npz')\n",
"\n",
" dir = \"./gdrive/My Drive/\"\n",
"\n",
"npfile=np.load(dir+'data-airfoils.npz')\n",
" \n",
"print(\"Loaded data, {} training, {} validation samples\".format(len(npfile[\"inputs\"]),len(npfile[\"vinputs\"])))\n",
"\n",
"print(\"Size of the inputs array: \"+format(npfile[\"inputs\"].shape))\n"
"print(\"Size of the inputs array: \"+format(npfile[\"inputs\"].shape))"
]
},
{