bug fixes
This commit is contained in:
@@ -230,6 +230,7 @@
|
|||||||
" idx_2 = [[lang_2.word2index[word] for word in s.split(' ')]\n",
|
" idx_2 = [[lang_2.word2index[word] for word in s.split(' ')]\n",
|
||||||
" for s in self.pairs[:, 1]]\n",
|
" for s in self.pairs[:, 1]]\n",
|
||||||
" self.idx_pairs = np.array(list(zip(idx_1, idx_2)))\n",
|
" self.idx_pairs = np.array(list(zip(idx_1, idx_2)))\n",
|
||||||
|
" self.shuffle_idx = np.arange(len(pairs))\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def __str__(self):\n",
|
" def __str__(self):\n",
|
||||||
" return(self.pairs)\n",
|
" return(self.pairs)\n",
|
||||||
@@ -238,7 +239,6 @@
|
|||||||
" np.random.shuffle(self.shuffle_idx)\n",
|
" np.random.shuffle(self.shuffle_idx)\n",
|
||||||
" self.pairs = self.pairs[self.shuffle_idx]\n",
|
" self.pairs = self.pairs[self.shuffle_idx]\n",
|
||||||
" self.idx_pairs = self.idx_pairs[self.shuffle_idx] \n",
|
" self.idx_pairs = self.idx_pairs[self.shuffle_idx] \n",
|
||||||
" \n",
|
|
||||||
" "
|
" "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -329,7 +329,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"class Encoder(nn.Module):\n",
|
"class Encoder(nn.Module):\n",
|
||||||
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
|
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device):\n",
|
||||||
" super(Encoder, self).__init__()\n",
|
" super(Encoder, self).__init__()\n",
|
||||||
" self.bidirectional = bidirectional\n",
|
" self.bidirectional = bidirectional\n",
|
||||||
" self.hidden_size = hidden_size\n",
|
" self.hidden_size = hidden_size\n",
|
||||||
@@ -401,7 +401,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"class Decoder(nn.Module):\n",
|
"class Decoder(nn.Module):\n",
|
||||||
" def __init__(self, embedding_size, hidden_size, output_size, device=device.type):\n",
|
" def __init__(self, embedding_size, hidden_size, output_size, device=device):\n",
|
||||||
" super(Decoder, self).__init__()\n",
|
" super(Decoder, self).__init__()\n",
|
||||||
" self.decoder = 'simple'\n",
|
" self.decoder = 'simple'\n",
|
||||||
" self.hidden_size = hidden_size\n",
|
" self.hidden_size = hidden_size\n",
|
||||||
@@ -522,6 +522,8 @@
|
|||||||
"m = Encoder(eng.n_words, embedding_size, hidden_size, bidirectional=False, device='cpu')\n",
|
"m = Encoder(eng.n_words, embedding_size, hidden_size, bidirectional=False, device='cpu')\n",
|
||||||
"scentence = torch.tensor([1, 23, 9])\n",
|
"scentence = torch.tensor([1, 23, 9])\n",
|
||||||
"out, h = m(scentence)\n",
|
"out, h = m(scentence)\n",
|
||||||
|
"print(out.shape)\n",
|
||||||
|
"\n",
|
||||||
"encoder_outputs = torch.zeros(max_length, out.shape[-1], device='cpu')\n",
|
"encoder_outputs = torch.zeros(max_length, out.shape[-1], device='cpu')\n",
|
||||||
"encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
|
"encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -536,18 +538,18 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def run_decoder(decoder, scentence, h, teacher_forcing=False, encoder_outputs=None):\n",
|
"def run_decoder(decoder, criterion, sentence, h, teacher_forcing=False, encoder_outputs=None):\n",
|
||||||
" loss = 0\n",
|
" loss = 0\n",
|
||||||
" \n",
|
" word = torch.tensor([0], device=device) # <SOS>\n",
|
||||||
" for j in range(scentence.shape[0]):\n",
|
" for j in range(sentence.shape[0]):\n",
|
||||||
" if decoder.decoder == 'Attention':\n",
|
" if decoder.decoder == 'Attention':\n",
|
||||||
" x, h = decoder(word, h, encoder_outputs)\n",
|
" x, h = decoder(word, h, encoder_outputs)\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" x, h = decoder(word, h)\n",
|
" x, h = decoder(word, h)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" loss += criterion(x.view(1, -1), scentence[j].view(-1))\n",
|
" loss += criterion(x.view(1, -1), sentence[j].view(-1))\n",
|
||||||
" if teacher_forcing:\n",
|
" if teacher_forcing:\n",
|
||||||
" word = eng_scentence[j]\n",
|
" word = sentence[j]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" word = x.argmax().detach()\n",
|
" word = x.argmax().detach()\n",
|
||||||
" if word.item() == 1: # <EOS>\n",
|
" if word.item() == 1: # <EOS>\n",
|
||||||
@@ -569,7 +571,7 @@
|
|||||||
"encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n",
|
"encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n",
|
||||||
"context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n",
|
"context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n",
|
||||||
"decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n",
|
"decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n",
|
||||||
"writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')"
|
"# writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -603,23 +605,24 @@
|
|||||||
" fra_scentence = torch.tensor(pair[1], device=device)\n",
|
" fra_scentence = torch.tensor(pair[1], device=device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Encode the input language\n",
|
" # Encode the input language\n",
|
||||||
" out, h = encoder(fra_scentence) \n",
|
" out, h = encoder(eng_scentence) \n",
|
||||||
" encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
|
" encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" if decoder.decoder == 'attention':\n",
|
" if decoder.decoder == 'attention':\n",
|
||||||
" encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
|
" encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" word = torch.tensor([0], device=device) # <SOS>\n",
|
|
||||||
" teacher_forcing = np.random.rand() < teacher_forcing_ratio\n",
|
" teacher_forcing = np.random.rand() < teacher_forcing_ratio\n",
|
||||||
" loss = run_decoder(decoder, eng_scentence, h, teacher_forcing)\n",
|
" loss = run_decoder(decoder, criterion, fra_scentence, h, teacher_forcing)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" loss.backward()\n",
|
" loss.backward()\n",
|
||||||
" writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n",
|
"# writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" optim_decoder.step()\n",
|
" optim_decoder.step()\n",
|
||||||
" optim_encoder.step()\n",
|
" optim_encoder.step()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" print(f'epoch {epoch}')\n"
|
" print(f'epoch {epoch}')\n",
|
||||||
|
"\n",
|
||||||
|
"train(encoder, decoder)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user