Fix loss equations

This commit is contained in:
Sylvain Gugger
2020-04-21 20:06:19 -07:00
parent 87ec9a1b0f
commit 5fb6482ded
3 changed files with 50 additions and 12 deletions

View File

@@ -313,7 +313,8 @@
{
"data": {
"text/plain": [
"(Path('train/002844.jpg'), ['train'])"
"(Path('/home/sgugger/.fastai/data/pascal_2007/train/008663.jpg'),\n",
" ['car', 'person'])"
]
},
"execution_count": null,
@@ -488,8 +489,8 @@
{
"data": {
"text/plain": [
"tensor([-1.0028, 0.3400, -0.5906, 0.7806, 3.1160, -0.1994, 1.3180, 1.6361, -1.7553, 0.2217, 2.8052, 1.3229, 0.9369, -1.4760, -0.3204, -2.3116, -3.8615, -1.5931, 0.0745, -3.6006],\n",
" device='cuda:5', grad_fn=<SelectBackward>)"
"tensor([ 2.0258, -1.3543, 1.4640, 1.7754, -1.2820, -5.8053, 3.6130, 0.7193, -4.3683, -2.5001, -2.8373, -1.8037, 2.0122, 0.6189, 1.9729, 0.8999, -2.6769, -0.3829, 1.2212, 1.6073],\n",
" device='cuda:0', grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
@@ -509,7 +510,25 @@
"source": [
"def binary_cross_entropy(inputs, targets):\n",
" inputs = inputs.sigmoid()\n",
" return torch.where(targets==1, 1-inputs, inputs).log().mean()"
" return torch.where(targets==1, inputs, 1-inputs).log().mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"binary_cross_entropy(activs, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"F.binary_cross_entropy_with_logits(activs, y)"
]
},
{