specify n_out in call to pretrained model function to replace the head with a new one (#254)
Co-authored-by: Jeremy Howard <j@fast.ai>
This commit is contained in:
parent
e132978d8d
commit
741295a8b1
@ -179,7 +179,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = xresnet50()\n",
|
||||
"model = xresnet50(n_out=dls.c)\n",
|
||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
||||
"learn.fit_one_cycle(5, 3e-3)"
|
||||
]
|
||||
@ -358,7 +358,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = xresnet50()\n",
|
||||
"model = xresnet50(n_out=dls.c)\n",
|
||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
||||
"learn.fit_one_cycle(5, 3e-3)"
|
||||
]
|
||||
@ -472,7 +472,7 @@
|
||||
],
|
||||
"source": [
|
||||
"dls = get_dls(128, 128)\n",
|
||||
"learn = Learner(dls, xresnet50(), loss_func=CrossEntropyLossFlat(), \n",
|
||||
"learn = Learner(dls, xresnet50(n_out=dls.c), loss_func=CrossEntropyLossFlat(), \n",
|
||||
" metrics=accuracy)\n",
|
||||
"learn.fit_one_cycle(4, 3e-3)"
|
||||
]
|
||||
@ -846,7 +846,7 @@
|
||||
"Here is how we train a model with Mixup:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"model = xresnet50()\n",
|
||||
"model = xresnet50(n_out=dls.c)\n",
|
||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), \n",
|
||||
" metrics=accuracy, cbs=MixUp())\n",
|
||||
"learn.fit_one_cycle(5, 3e-3)\n",
|
||||
@ -937,7 +937,7 @@
|
||||
"To use this in practice, we just have to change the loss function in our call to `Learner`:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"model = xresnet50()\n",
|
||||
"model = xresnet50(n_out=dls.c)\n",
|
||||
"learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(), \n",
|
||||
" metrics=accuracy)\n",
|
||||
"learn.fit_one_cycle(5, 3e-3)\n",
|
||||
@ -1019,8 +1019,20 @@
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user