diff --git a/seq2seq/attention_decoder.ipynb b/seq2seq/attention_decoder.ipynb
index dddbcb6..0dd5dfc 100644
--- a/seq2seq/attention_decoder.ipynb
+++ b/seq2seq/attention_decoder.ipynb
@@ -230,6 +230,7 @@
     "        idx_2 = [[lang_2.word2index[word] for word in s.split(' ')]\n",
     "                 for s in self.pairs[:, 1]]\n",
     "        self.idx_pairs = np.array(list(zip(idx_1, idx_2)))\n",
+    "        self.shuffle_idx = np.arange(len(pairs))\n",
     "    \n",
     "    def __str__(self):\n",
     "        return(self.pairs)\n",
@@ -238,8 +239,7 @@
     "        np.random.shuffle(self.shuffle_idx)\n",
     "        self.pairs = self.pairs[self.shuffle_idx]\n",
     "        self.idx_pairs = self.idx_pairs[self.shuffle_idx] \n",
-    "        \n",
-    "        "
+    "        "
    ]
   },
   {
@@ -329,7 +329,7 @@
    ],
    "source": [
     "class Encoder(nn.Module):\n",
-    "    def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
+    "    def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device):\n",
     "        super(Encoder, self).__init__()\n",
     "        self.bidirectional = bidirectional\n",
     "        self.hidden_size = hidden_size\n",
@@ -401,7 +401,7 @@
    ],
    "source": [
     "class Decoder(nn.Module):\n",
-    "    def __init__(self, embedding_size, hidden_size, output_size, device=device.type):\n",
+    "    def __init__(self, embedding_size, hidden_size, output_size, device=device):\n",
     "        super(Decoder, self).__init__()\n",
     "        self.decoder = 'simple'\n",
     "        self.hidden_size = hidden_size\n",
@@ -522,6 +522,8 @@
     "m = Encoder(eng.n_words, embedding_size, hidden_size, bidirectional=False, device='cpu')\n",
     "scentence = torch.tensor([1, 23, 9])\n",
     "out, h = m(scentence)\n",
+    "print(out.shape)\n",
+    "\n",
     "encoder_outputs = torch.zeros(max_length, out.shape[-1], device='cpu')\n",
     "encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
     "\n",
@@ -536,18 +538,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def run_decoder(decoder, scentence, h, teacher_forcing=False, encoder_outputs=None):\n",
+    "def run_decoder(decoder, criterion, sentence, h, teacher_forcing=False, encoder_outputs=None):\n",
     "    loss = 0\n",
-    "    \n",
-    "    for j in range(scentence.shape[0]):\n",
+    "    word = torch.tensor([0], device=device) # <SOS>\n",
+    "    for j in range(sentence.shape[0]):\n",
     "        if decoder.decoder == 'Attention':\n",
     "            x, h = decoder(word, h, encoder_outputs)\n",
     "        else:\n",
     "            x, h = decoder(word, h)\n",
     "\n",
-    "        loss += criterion(x.view(1, -1), scentence[j].view(-1))\n",
+    "        loss += criterion(x.view(1, -1), sentence[j].view(-1))\n",
     "        if teacher_forcing:\n",
-    "            word = eng_scentence[j]\n",
+    "            word = sentence[j]\n",
     "        else:\n",
     "            word = x.argmax().detach()\n",
     "        if word.item() == 1: # <EOS>\n",
@@ -569,7 +571,7 @@
     "encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n",
     "context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n",
     "decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n",
-    "writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')"
+    "# writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')"
    ]
   },
   {
@@ -603,23 +605,24 @@
     "        fra_scentence = torch.tensor(pair[1], device=device)\n",
     "\n",
     "        # Encode the input language\n",
-    "        out, h = encoder(fra_scentence) \n",
+    "        out, h = encoder(eng_scentence) \n",
     "        encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
     "        \n",
     "        if decoder.decoder == 'attention':\n",
     "            encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
     "\n",
-    "        word = torch.tensor([0], device=device) # <SOS>\n",
     "        teacher_forcing = np.random.rand() < teacher_forcing_ratio\n",
-    "        loss = run_decoder(decoder, eng_scentence, h, teacher_forcing)\n",
+    "        loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing)\n",
     "\n",
     "        loss.backward()\n",
-    "        writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n",
+    "#         writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n",
     "\n",
     "        optim_decoder.step()\n",
     "        optim_encoder.step()\n",
     "\n",
-    "    print(f'epoch {epoch}')\n"
+    "    print(f'epoch {epoch}')\n",
+    "\n",
+    "train(encoder, decoder)"
    ]
   },
   {
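For reference, the per-sentence training step that the diff converges on can be read as the following standalone sketch. It is a sketch, not the notebook's code: `train_step`, the `zero_grad()` calls, the cleaned-up argument names, and the hard-coded `teacher_forcing_ratio` are assumptions added here, and passing `encoder_outputs` through to `run_decoder` is also an addition (the diff's own call leaves it at its default). SOS=0 and EOS=1 follow the notebook's token convention.

    import numpy as np
    import torch

    teacher_forcing_ratio = 0.5  # assumed value; the notebook defines its own

    def train_step(encoder, decoder, criterion, optim_encoder, optim_decoder,
                   eng_sentence, fra_sentence, max_length, device='cpu'):
        # Encode the *source* (English) sentence -- the direction the diff fixes.
        optim_encoder.zero_grad()
        optim_decoder.zero_grad()
        out, h = encoder(eng_sentence)

        # Pad the encoder states out to max_length for the attention decoder.
        encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)
        if decoder.decoder == 'attention':
            encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)

        # Score the decoder against the *target* (French) sentence,
        # flipping a coin once per sentence for teacher forcing.
        teacher_forcing = np.random.rand() < teacher_forcing_ratio
        loss = run_decoder(decoder, criterion, fra_sentence, h,
                           teacher_forcing, encoder_outputs)

        loss.backward()
        optim_encoder.step()
        optim_decoder.step()
        return loss.item()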