specify n_out in call to pretrained model function to replace the head with a new one (#254)
Co-authored-by: Jeremy Howard <j@fast.ai>
This commit is contained in:
parent
e132978d8d
commit
741295a8b1
@ -179,7 +179,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model = xresnet50()\n",
|
"model = xresnet50(n_out=dls.c)\n",
|
||||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
||||||
"learn.fit_one_cycle(5, 3e-3)"
|
"learn.fit_one_cycle(5, 3e-3)"
|
||||||
]
|
]
|
||||||
@ -358,7 +358,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model = xresnet50()\n",
|
"model = xresnet50(n_out=dls.c)\n",
|
||||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)\n",
|
||||||
"learn.fit_one_cycle(5, 3e-3)"
|
"learn.fit_one_cycle(5, 3e-3)"
|
||||||
]
|
]
|
||||||
@ -472,7 +472,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"dls = get_dls(128, 128)\n",
|
"dls = get_dls(128, 128)\n",
|
||||||
"learn = Learner(dls, xresnet50(), loss_func=CrossEntropyLossFlat(), \n",
|
"learn = Learner(dls, xresnet50(n_out=dls.c), loss_func=CrossEntropyLossFlat(), \n",
|
||||||
" metrics=accuracy)\n",
|
" metrics=accuracy)\n",
|
||||||
"learn.fit_one_cycle(4, 3e-3)"
|
"learn.fit_one_cycle(4, 3e-3)"
|
||||||
]
|
]
|
||||||
@ -846,7 +846,7 @@
|
|||||||
"Here is how we train a model with Mixup:\n",
|
"Here is how we train a model with Mixup:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"```python\n",
|
"```python\n",
|
||||||
"model = xresnet50()\n",
|
"model = xresnet50(n_out=dls.c)\n",
|
||||||
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), \n",
|
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), \n",
|
||||||
" metrics=accuracy, cbs=MixUp())\n",
|
" metrics=accuracy, cbs=MixUp())\n",
|
||||||
"learn.fit_one_cycle(5, 3e-3)\n",
|
"learn.fit_one_cycle(5, 3e-3)\n",
|
||||||
@ -937,7 +937,7 @@
|
|||||||
"To use this in practice, we just have to change the loss function in our call to `Learner`:\n",
|
"To use this in practice, we just have to change the loss function in our call to `Learner`:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"```python\n",
|
"```python\n",
|
||||||
"model = xresnet50()\n",
|
"model = xresnet50(n_out=dls.c)\n",
|
||||||
"learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(), \n",
|
"learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(), \n",
|
||||||
" metrics=accuracy)\n",
|
" metrics=accuracy)\n",
|
||||||
"learn.fit_one_cycle(5, 3e-3)\n",
|
"learn.fit_one_cycle(5, 3e-3)\n",
|
||||||
@ -1019,8 +1019,20 @@
|
|||||||
"display_name": "Python 3",
|
"display_name": "Python 3",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 2
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user