vision_learner
This commit is contained in:
@@ -107,7 +107,7 @@
|
||||
" path, get_image_files(path), valid_pct=0.2, seed=42,\n",
|
||||
" label_func=is_cat, item_tfms=Resize(224))\n",
|
||||
"\n",
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fine_tune(1)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -358,7 +358,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet18, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet18, metrics=error_rate)\n",
|
||||
"learn.fine_tune(4)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -1553,7 +1553,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dls = ImageDataLoaders.from_folder(path)\n",
|
||||
"learn = cnn_learner(dls, resnet18, pretrained=False,\n",
|
||||
"learn = vision_learner(dls, resnet18, pretrained=False,\n",
|
||||
" loss_func=F.cross_entropy, metrics=accuracy)\n",
|
||||
"learn.fit_one_cycle(1, 0.1)"
|
||||
]
|
||||
|
||||
@@ -178,7 +178,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fine_tune(2)"
|
||||
]
|
||||
},
|
||||
@@ -499,7 +499,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fine_tune(1, base_lr=0.1)"
|
||||
]
|
||||
},
|
||||
@@ -509,7 +509,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"lr_min,lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))"
|
||||
]
|
||||
},
|
||||
@@ -528,7 +528,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fine_tune(2, base_lr=3e-3)"
|
||||
]
|
||||
},
|
||||
@@ -554,7 +554,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fit_one_cycle(3, 3e-3)"
|
||||
]
|
||||
},
|
||||
@@ -598,7 +598,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fit_one_cycle(3, 3e-3)\n",
|
||||
"learn.unfreeze()\n",
|
||||
"learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))"
|
||||
@@ -634,7 +634,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastai.callback.fp16 import *\n",
|
||||
"learn = cnn_learner(dls, resnet50, metrics=error_rate).to_fp16()\n",
|
||||
"learn = vision_learner(dls, resnet50, metrics=error_rate).to_fp16()\n",
|
||||
"learn.fine_tune(6, freeze_epochs=3)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -295,7 +295,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet18)"
|
||||
"learn = vision_learner(dls, resnet18)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -366,7 +366,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
|
||||
"learn = vision_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
|
||||
"learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)"
|
||||
]
|
||||
},
|
||||
@@ -580,7 +580,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet18, y_range=(-1,1))"
|
||||
"learn = vision_learner(dls, resnet18, y_range=(-1,1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### cnn_learner"
|
||||
"### vision_learner"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -256,7 +256,7 @@
|
||||
"1. What is `model_meta`? Try printing it to see what's inside.\n",
|
||||
"1. Read the source code for `create_head` and make sure you understand what each line does.\n",
|
||||
"1. Look at the output of `create_head` and make sure you understand why each layer is there, and how the `create_head` source created it.\n",
|
||||
"1. Figure out how to change the dropout, layer size, and number of layers created by `cnn_learner`, and see if you can find values that result in better accuracy from the pet recognizer.\n",
|
||||
"1. Figure out how to change the dropout, layer size, and number of layers created by `vision_learner`, and see if you can find values that result in better accuracy from the pet recognizer.\n",
|
||||
"1. What does `AdaptiveConcatPool2d` do?\n",
|
||||
"1. What is \"nearest neighbor interpolation\"? How can it be used to upsample convolutional activations?\n",
|
||||
"1. What is a \"transposed convolution\"? What is another name for it?\n",
|
||||
|
||||
@@ -69,7 +69,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_learner(**kwargs):\n",
|
||||
" return cnn_learner(dls, resnet34, pretrained=False,\n",
|
||||
" return vision_learner(dls, resnet34, pretrained=False,\n",
|
||||
" metrics=accuracy, **kwargs).to_fp16()"
|
||||
]
|
||||
},
|
||||
@@ -386,7 +386,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. What is the equation for a step of SGD, in math or code (as you prefer)?\n",
|
||||
"1. What do we pass to `cnn_learner` to use a non-default optimizer?\n",
|
||||
"1. What do we pass to `vision_learner` to use a non-default optimizer?\n",
|
||||
"1. What are optimizer callbacks?\n",
|
||||
"1. What does `zero_grad` do in an optimizer?\n",
|
||||
"1. What does `step` do in an optimizer? How is it implemented in the general optimizer?\n",
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
"dls = ImageDataLoaders.from_name_func(\n",
|
||||
" path, get_image_files(path), valid_pct=0.2, seed=21,\n",
|
||||
" label_func=is_cat, item_tfms=Resize(224))\n",
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"learn.fine_tune(1)"
|
||||
]
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user