diff --git a/seq2seq/attention_decoder.ipynb b/seq2seq/attention_decoder.ipynb index 0dd5dfc..9d8bdd3 100644 --- a/seq2seq/attention_decoder.ipynb +++ b/seq2seq/attention_decoder.ipynb @@ -460,7 +460,7 @@ ], "source": [ "class AttentionDecoder(nn.Module):\n", - " def __init__(self, embedding_size, hidden_size, output_size, dropout=0.1, max_length=10, device='cpu'):\n", + " def __init__(self, embedding_size, hidden_size, output_size, dropout=0.1, max_length=10, device=device):\n", " super(AttentionDecoder, self).__init__()\n", " self.decoder = 'attention'\n", " self.max_length = max_length\n", @@ -542,7 +542,7 @@ " loss = 0\n", " word = torch.tensor([0], device=device) # \n", " for j in range(sentence.shape[0]):\n", - " if decoder.decoder == 'Attention':\n", + " if decoder.decoder == 'attention':\n", " x, h = decoder(word, h, encoder_outputs)\n", " else:\n", " x, h = decoder(word, h)\n", @@ -570,8 +570,8 @@ "bidirectional = False\n", "encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n", "context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n", - "decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n", - "# writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')" + "decoder = AttentionDecoder(embedding_size, context_vector_size, fra.n_words)\n", + "writer = SummaryWriter('tb/train')" ] }, { @@ -612,10 +612,10 @@ " encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n", "\n", " teacher_forcing = np.random.rand() < teacher_forcing_ratio\n", - " loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing)\n", + " loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing, encoder_outputs)\n", "\n", " loss.backward()\n", - "# writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n", + " writer.add_scalar('loss', loss.cpu().item() / (i + 1))\n", "\n", " optim_decoder.step()\n", " optim_encoder.step()\n", @@ -631,32 +631,36 @@ 
"metadata": {}, "outputs": [], "source": [ - "def translate(start, end):\n", - " with torch.no_grad():\n", - " for i in range(start, end):\n", + "def translate( start, end):\n", + " \n", + " for i in range(start, end):\n", + " \n", + " pair = data.idx_pairs[i]\n", + " eng_sentence = torch.tensor(pair[0], device=device)\n", + " fra_sentence = torch.tensor(pair[1], device=device)\n", "\n", - " pair = data.idx_pairs[i]\n", - " eng_scentence = torch.tensor(pair[0], device=device)\n", - " fra_scentence = torch.tensor(pair[1], device=device)\n", + " print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_sentence.cpu().data.numpy()][:-1]))\n", + " print('French scentence:\\t', ' '.join([fra.index2word[i] for i in fra_sentence.cpu().data.numpy()][:-1]))\n", "\n", - " print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_scentence.cpu().data.numpy()][:-1]))\n", - " print('Real translation:\\t', ' '.join([fra.index2word[i] for i in fra_scentence.cpu().data.numpy()][:-1]))\n", - "\n", - " h = encoder(eng_scentence)\n", - " word = torch.tensor([0], device=device)\n", - "\n", - " translation = []\n", - " for j in range(fra_scentence.shape[0]):\n", - " x, h = decoder(word, h)\n", - "\n", - " word = x.argmax().detach()\n", - " translation.append(word.cpu().data.tolist())\n", - "\n", - " if word.item() == 1: # \n", - " break\n", - " print('Model translation:\\t', ' '.join([fra.index2word[i] for i in translation][:-1]), '\\n')\n", + " # Encode the input language\n", + " out, h = encoder(eng_sentence) \n", + " encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n", + " encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n", " \n", - "translate(10, 20)" + " word = torch.tensor([0], device=device) # \n", + " \n", + " translation = []\n", + " for j in range(eng_sentence.shape[0]):\n", + " x, h = decoder(word, h, encoder_outputs=encoder_outputs)\n", + " \n", + " word = x.argmax().detach()\n", + " 
translation.append(word.cpu().data.tolist())\n", + "\n", + " if word.item() == 1: # \n", + " break\n", + " print('\\nModel translation:\\t', ' '.join([fra.index2word[i] for i in translation][:-1]), '\\n\\n')\n", + " \n", + "translate(20, 60)" ] }, {