specify n_out in call to pretrained model function to replace the head with a new one (#254)

Co-authored-by: Jeremy Howard <j@fast.ai>
This commit is contained in:
Kofi Asiedu Brempong 2020-11-29 15:06:46 +00:00 committed by GitHub
parent e132978d8d
commit 741295a8b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -179,7 +179,7 @@
}
],
"source": [
"model = xresnet50()\n",
"model = xresnet50(n_out=dls.c)\n",
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
"learn.fit_one_cycle(5, 3e-3)"
]
@ -358,7 +358,7 @@
}
],
"source": [
"model = xresnet50()\n",
"model = xresnet50(n_out=dls.c)\n",
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
"learn.fit_one_cycle(5, 3e-3)"
]
@ -472,7 +472,7 @@
],
"source": [
"dls = get_dls(128, 128)\n",
"learn = Learner(dls, xresnet50(), loss_func=CrossEntropyLossFlat(), \n",
"learn = Learner(dls, xresnet50(n_out=dls.c), loss_func=CrossEntropyLossFlat(), \n",
" metrics=accuracy)\n",
"learn.fit_one_cycle(4, 3e-3)"
]
@ -846,7 +846,7 @@
"Here is how we train a model with Mixup:\n",
"\n",
"```python\n",
"model = xresnet50()\n",
"model = xresnet50(n_out=dls.c)\n",
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), \n",
" metrics=accuracy, cbs=MixUp())\n",
"learn.fit_one_cycle(5, 3e-3)\n",
@ -937,7 +937,7 @@
"To use this in practice, we just have to change the loss function in our call to `Learner`:\n",
"\n",
"```python\n",
"model = xresnet50()\n",
"model = xresnet50(n_out=dls.c)\n",
"learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(), \n",
" metrics=accuracy)\n",
"learn.fit_one_cycle(5, 3e-3)\n",
@ -1019,6 +1019,18 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,