update explanation of negative log loss (cross entropy loss) (#501)
* update explanation of nll * spelling * clean * clean * add back stuff * fix lr syntax
This commit is contained in:
@@ -396,7 +396,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tensor([1,2,3]) + tensor([1,1,1])"
|
||||
"tensor([1,2,3]) + tensor(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -956,7 +956,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"corrects = (preds>0.0).float() == train_y\n",
|
||||
"corrects = (preds>0.5).float() == train_y\n",
|
||||
"corrects"
|
||||
]
|
||||
},
|
||||
@@ -1643,7 +1643,7 @@
|
||||
"split_at_heading": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
|
||||
@@ -123,6 +123,7 @@
|
||||
"dblock1 = DataBlock(blocks=(ImageBlock(), CategoryBlock()),\n",
|
||||
" get_y=parent_label,\n",
|
||||
" item_tfms=Resize(460))\n",
|
||||
"# Place an image in the 'images/grizzly.jpg' subfolder where this notebook is located before running this\n",
|
||||
"dls1 = dblock1.dataloaders([(Path.cwd()/'images'/'grizzly.jpg')]*100, bs=8)\n",
|
||||
"dls1.train.get_idxs = lambda: Inf.ones\n",
|
||||
"x,y = dls1.valid.one_batch()\n",
|
||||
@@ -341,7 +342,7 @@
|
||||
"df = pd.DataFrame(sm_acts, columns=[\"3\",\"7\"])\n",
|
||||
"df['targ'] = targ\n",
|
||||
"df['idx'] = idx\n",
|
||||
"df['loss'] = sm_acts[range(6), targ]\n",
|
||||
"df['result'] = sm_acts[range(6), targ]\n",
|
||||
"t = df.style.hide_index()\n",
|
||||
"#To have html code compatible with our script\n",
|
||||
"html = t._repr_html_().split('</style>')[1]\n",
|
||||
@@ -371,7 +372,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Taking the Log"
|
||||
"#### Taking the Log\n",
|
||||
"\n",
|
||||
"Recall that cross entropy loss may involve the multiplication of many numbers. Multiplying lots of negative numbers together can cause problems like [numerical underflow](https://en.wikipedia.org/wiki/Arithmetic_underflow) in computers. Therefore, we want to transform these probabilities to larger values so we can perform mathematical operations on them. There is a mathematical function that does exactly this: the *logarithm* (available as `torch.log`). It is not defined for numbers less than 0, and looks like this between 0 and 1:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -380,7 +383,38 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_function(torch.log, min=0,max=4)"
|
||||
"plot_function(torch.log, min=0,max=1, ty='log(x)', tx='x')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_function(lambda x: -1*torch.log(x), min=0,max=1, tx='x', ty='- log(x)', title = 'Log Loss when true label = 1')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from IPython.display import HTML\n",
|
||||
"df['loss'] = -torch.log(tensor(df['result']))\n",
|
||||
"t = df.style.hide_index()\n",
|
||||
"#To have html code compatible with our script\n",
|
||||
"html = t._repr_html_().split('</style>')[1]\n",
|
||||
"html = re.sub(r'<table id=\"([^\"]+)\"\\s*>', r'<table >', html)\n",
|
||||
"display(HTML(html))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Negative Log Likelihood"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -476,7 +510,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
|
||||
"lr_min,lr_steep = learn.lr_find()"
|
||||
"lr_min,lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -675,11 +709,11 @@
|
||||
"split_at_heading": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user