bug fixes

This commit is contained in:
ritchie46
2018-12-07 12:24:10 +01:00
parent ecf4cbe017
commit 04a3f555de

View File

@@ -460,7 +460,7 @@
], ],
"source": [ "source": [
"class AttentionDecoder(nn.Module):\n", "class AttentionDecoder(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, output_size, dropout=0.1, max_length=10, device='cpu'):\n", " def __init__(self, embedding_size, hidden_size, output_size, dropout=0.1, max_length=10, device=device):\n",
" super(AttentionDecoder, self).__init__()\n", " super(AttentionDecoder, self).__init__()\n",
" self.decoder = 'attention'\n", " self.decoder = 'attention'\n",
" self.max_length = max_length\n", " self.max_length = max_length\n",
@@ -542,7 +542,7 @@
" loss = 0\n", " loss = 0\n",
" word = torch.tensor([0], device=device) # <SOS>\n", " word = torch.tensor([0], device=device) # <SOS>\n",
" for j in range(sentence.shape[0]):\n", " for j in range(sentence.shape[0]):\n",
" if decoder.decoder == 'Attention':\n", " if decoder.decoder == 'attention':\n",
" x, h = decoder(word, h, encoder_outputs)\n", " x, h = decoder(word, h, encoder_outputs)\n",
" else:\n", " else:\n",
" x, h = decoder(word, h)\n", " x, h = decoder(word, h)\n",
@@ -570,8 +570,8 @@
"bidirectional = False\n", "bidirectional = False\n",
"encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n", "encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n",
"context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n", "context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n",
"decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n", "decoder = AttentionDecoder(embedding_size, context_vector_size, fra.n_words)\n",
"# writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')" "# writer = SummaryWriter('tb/train')"
] ]
}, },
{ {
@@ -612,10 +612,10 @@
" encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n", " encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
"\n", "\n",
" teacher_forcing = np.random.rand() < teacher_forcing_ratio\n", " teacher_forcing = np.random.rand() < teacher_forcing_ratio\n",
" loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing)\n", " loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing, encoder_outputs)\n",
"\n", "\n",
" loss.backward()\n", " loss.backward()\n",
"# writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n", " writer.add_scalar('loss', loss.cpu().item() / (i + 1))\n",
"\n", "\n",
" optim_decoder.step()\n", " optim_decoder.step()\n",
" optim_encoder.step()\n", " optim_encoder.step()\n",
@@ -631,32 +631,36 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def translate(start, end):\n", "def translate( start, end):\n",
" with torch.no_grad():\n", " \n",
" for i in range(start, end):\n", " for i in range(start, end):\n",
" \n",
" pair = data.idx_pairs[i]\n",
" eng_sentence = torch.tensor(pair[0], device=device)\n",
" fra_sentence = torch.tensor(pair[1], device=device)\n",
"\n", "\n",
" pair = data.idx_pairs[i]\n", " print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_sentence.cpu().data.numpy()][:-1]))\n",
" eng_scentence = torch.tensor(pair[0], device=device)\n", " print('French scentence:\\t', ' '.join([fra.index2word[i] for i in fra_sentence.cpu().data.numpy()][:-1]))\n",
" fra_scentence = torch.tensor(pair[1], device=device)\n",
"\n", "\n",
" print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_scentence.cpu().data.numpy()][:-1]))\n", " # Encode the input language\n",
" print('Real translation:\\t', ' '.join([fra.index2word[i] for i in fra_scentence.cpu().data.numpy()][:-1]))\n", " out, h = encoder(eng_sentence) \n",
"\n", " encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
" h = encoder(eng_scentence)\n", " encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
" word = torch.tensor([0], device=device)\n",
"\n",
" translation = []\n",
" for j in range(fra_scentence.shape[0]):\n",
" x, h = decoder(word, h)\n",
"\n",
" word = x.argmax().detach()\n",
" translation.append(word.cpu().data.tolist())\n",
"\n",
" if word.item() == 1: # <EOS>\n",
" break\n",
" print('Model translation:\\t', ' '.join([fra.index2word[i] for i in translation][:-1]), '\\n')\n",
" \n", " \n",
"translate(10, 20)" " word = torch.tensor([0], device=device) # <SOS>\n",
" \n",
" translation = []\n",
" for j in range(eng_sentence.shape[0]):\n",
" x, h = decoder(word, h, encoder_outputs=encoder_outputs)\n",
" \n",
" word = x.argmax().detach()\n",
" translation.append(word.cpu().data.tolist())\n",
"\n",
" if word.item() == 1: # <EOS>\n",
" break\n",
" print('\\nModel translation:\\t', ' '.join([eng.index2word[i] for i in translation][:-1]), '\\n\\n')\n",
" \n",
"translate(20, 60)"
] ]
}, },
{ {