more explanations in seq2seq
This commit is contained in:
@@ -30,7 +30,39 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
|
||||||
|
" Dload Upload Total Spent Left Speed\n",
|
||||||
|
"100 2814k 100 2814k 0 0 507k 0 0:00:05 0:00:05 --:--:-- 631k\n",
|
||||||
|
"Archive: data.zip\n",
|
||||||
|
" creating: data/\n",
|
||||||
|
" inflating: data/eng-fra.txt \n",
|
||||||
|
" creating: data/names/\n",
|
||||||
|
" inflating: data/names/Arabic.txt \n",
|
||||||
|
" inflating: data/names/Chinese.txt \n",
|
||||||
|
" inflating: data/names/Czech.txt \n",
|
||||||
|
" inflating: data/names/Dutch.txt \n",
|
||||||
|
" inflating: data/names/English.txt \n",
|
||||||
|
" inflating: data/names/French.txt \n",
|
||||||
|
" inflating: data/names/German.txt \n",
|
||||||
|
" inflating: data/names/Greek.txt \n",
|
||||||
|
" inflating: data/names/Irish.txt \n",
|
||||||
|
" inflating: data/names/Italian.txt \n",
|
||||||
|
" inflating: data/names/Japanese.txt \n",
|
||||||
|
" inflating: data/names/Korean.txt \n",
|
||||||
|
" inflating: data/names/Polish.txt \n",
|
||||||
|
" inflating: data/names/Portuguese.txt \n",
|
||||||
|
" inflating: data/names/Russian.txt \n",
|
||||||
|
" inflating: data/names/Scottish.txt \n",
|
||||||
|
" inflating: data/names/Spanish.txt \n",
|
||||||
|
" inflating: data/names/Vietnamese.txt \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# download the needed data\n",
|
"# download the needed data\n",
|
||||||
"if not os.path.isfile('data.zip'):\n",
|
"if not os.path.isfile('data.zip'):\n",
|
||||||
@@ -39,24 +71,68 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" de question !\n",
|
||||||
|
"Really?\tVraiment ?\n",
|
||||||
|
"Really?\tVrai ?\n",
|
||||||
|
"Really?\tAh bon ?\n",
|
||||||
|
"Thanks.\tMerci !\n",
|
||||||
|
"We try.\tOn essaye.\n",
|
||||||
|
"We won.\tNous avons gagné.\n",
|
||||||
|
"We won.\tNous gagnâmes.\n",
|
||||||
|
"We won.\tNous l'avons emporté.\n",
|
||||||
|
"We won.\tNous l'empor\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Take a quick view of the data.\n",
|
||||||
|
"with open('data/eng-fra.txt') as f:\n",
|
||||||
|
" f.seek(1000)\n",
|
||||||
|
" print(f.read(200))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"\n",
|
"\n",
|
||||||
"class Lang:\n",
|
"class Lang:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Utility class that serves as a language dictionary\n",
|
||||||
|
" \"\"\"\n",
|
||||||
" def __init__(self, name):\n",
|
" def __init__(self, name):\n",
|
||||||
" self.name = name\n",
|
" self.name = name\n",
|
||||||
|
" # Count how often a word occurs in the language data.\n",
|
||||||
" self.word2count = {}\n",
|
" self.word2count = {}\n",
|
||||||
|
" # Words are mapped to indices and vice versa\n",
|
||||||
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
|
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
|
||||||
" self.word2index = {v:k for k, v in self.index2word.items()}\n",
|
" self.word2index = {v:k for k, v in self.index2word.items()}\n",
|
||||||
|
" # Total word count\n",
|
||||||
" self.n_words = 2 # Count SOS and EOS\n",
|
" self.n_words = 2 # Count SOS and EOS\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def add_sentence(self, sentence):\n",
|
" def add_sentence(self, sentence):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Process words in a sentence string.\n",
|
||||||
|
" \n",
|
||||||
|
" :param sentence: (str) \n",
|
||||||
|
" \"\"\"\n",
|
||||||
" for word in sentence.split(' '):\n",
|
" for word in sentence.split(' '):\n",
|
||||||
" self.add_word(word)\n",
|
" self.add_word(word)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def add_word(self, word):\n",
|
" def add_word(self, word):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Process words\n",
|
||||||
|
" :param word: (str)\n",
|
||||||
|
" \"\"\"\n",
|
||||||
" if word not in self.word2index:\n",
|
" if word not in self.word2index:\n",
|
||||||
" self.word2index[word] = self.n_words\n",
|
" self.word2index[word] = self.n_words\n",
|
||||||
" self.word2count[word] = 1\n",
|
" self.word2count[word] = 1\n",
|
||||||
@@ -66,6 +142,9 @@
|
|||||||
" self.word2count[word] += 1\n",
|
" self.word2count[word] += 1\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def translate_indexes(self, idx):\n",
|
" def translate_indexes(self, idx):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Takes in a vector of indices and returns the sentence.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
" return [self.index2word[i] for i in idx]\n",
|
" return [self.index2word[i] for i in idx]\n",
|
||||||
" \n",
|
" \n",
|
||||||
"# Turn a Unicode string to plain ASCII, thanks to\n",
|
"# Turn a Unicode string to plain ASCII, thanks to\n",
|
||||||
@@ -106,7 +185,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 90,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -137,7 +216,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 91,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -165,7 +244,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -187,18 +266,23 @@
|
|||||||
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
|
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"def prepare_data(lang1, lang2, reverse=False):\n",
|
"def prepare_data(lang1, lang2, reverse=False):\n",
|
||||||
|
" # read_langs initializes the Lang objects (still empty) and returns the sentence pairs.\n",
|
||||||
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
|
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
|
||||||
" print(\"Read %s sentence pairs\" % len(pairs))\n",
|
" print(\"Read %s sentence pairs\" % len(pairs))\n",
|
||||||
|
" \n",
|
||||||
|
" # Reduce data. We haven't got all day to train a model.\n",
|
||||||
" pairs = filter_pairs(pairs) \n",
|
" pairs = filter_pairs(pairs) \n",
|
||||||
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
|
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
|
||||||
" print(\"Counting words...\")\n",
|
" print(\"Counting words...\")\n",
|
||||||
|
" \n",
|
||||||
|
" # Process the language pairs.\n",
|
||||||
" for pair in pairs:\n",
|
" for pair in pairs:\n",
|
||||||
" input_lang.add_sentence(pair[0])\n",
|
" input_lang.add_sentence(pair[0])\n",
|
||||||
" output_lang.add_sentence(pair[1])\n",
|
" output_lang.add_sentence(pair[1])\n",
|
||||||
@@ -220,17 +304,32 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
|
"The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
|
||||||
"\n",
|
"\n",
|
||||||
""
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Every output could be seen as the context of the sentence up to that point.\n",
|
||||||
|
"\n",
|
||||||
|
""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 96,
|
"execution_count": 20,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"torch.Size([5, 1, 2])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"class Encoder(nn.Module):\n",
|
"class Encoder(nn.Module):\n",
|
||||||
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=True, device=device.type):\n",
|
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
|
||||||
" super(Encoder, self).__init__()\n",
|
" super(Encoder, self).__init__()\n",
|
||||||
" self.bidirectional = bidirectional\n",
|
" self.bidirectional = bidirectional\n",
|
||||||
" self.hidden_size = hidden_size\n",
|
" self.hidden_size = hidden_size\n",
|
||||||
@@ -241,9 +340,7 @@
|
|||||||
" self.device = device\n",
|
" self.device = device\n",
|
||||||
" if device == 'cuda':\n",
|
" if device == 'cuda':\n",
|
||||||
" self.cuda()\n",
|
" self.cuda()\n",
|
||||||
" \n",
|
" \n",
|
||||||
" self.out = nn.Softmax\n",
|
|
||||||
" \n",
|
|
||||||
" def forward(self, x):\n",
|
" def forward(self, x):\n",
|
||||||
" # shape (seq_length, batch_size, input_size)\n",
|
" # shape (seq_length, batch_size, input_size)\n",
|
||||||
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
|
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
|
||||||
@@ -258,9 +355,10 @@
|
|||||||
" return x, h\n",
|
" return x, h\n",
|
||||||
" \n",
|
" \n",
|
||||||
"\n",
|
"\n",
|
||||||
"m = Encoder(eng.n_words, 10, 2, True, 'cpu')\n",
|
"m = Encoder(eng.n_words, 10, 2, False, 'cpu')\n",
|
||||||
"scentence = torch.tensor([400, 1, 2, 6, 8])\n",
|
"scentence = torch.tensor([400, 1, 2, 6, 8])\n",
|
||||||
"a = m(scentence)\n"
|
"a = m(scentence)\n",
|
||||||
|
"a[0].shape"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -273,21 +371,30 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoder’s last hidden state).\n",
|
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoder’s last hidden state).\n",
|
||||||
" \n",
|
" \n",
|
||||||
""
|
"\n",
|
||||||
|
" "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 123,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([1, 10])\n",
|
||||||
|
"torch.Size([1, 1, 10])\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"tensor(-23348.0801, grad_fn=<SumBackward0>)"
|
"tensor(-23351.8633, grad_fn=<SumBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 123,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -312,6 +419,11 @@
|
|||||||
" self.cuda()\n",
|
" self.cuda()\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def forward(self, word, h):\n",
|
" def forward(self, word, h):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param word: (tensor) Last word or start of sentence token.\n",
|
||||||
|
" :param h: (tensor) Hidden state or context tensor.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
|
||||||
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
|
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
|
||||||
" a = self.relu(word_embedding)\n",
|
" a = self.relu(word_embedding)\n",
|
||||||
" x, h = self.rnn(a, h)\n",
|
" x, h = self.rnn(a, h)\n",
|
||||||
@@ -357,6 +469,8 @@
|
|||||||
" nn.Embedding(output_size, embedding_size),\n",
|
" nn.Embedding(output_size, embedding_size),\n",
|
||||||
" nn.Dropout(dropout)\n",
|
" nn.Dropout(dropout)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # Separate neural network to learn the attention weights\n",
|
||||||
" self.attention_weights = nn.Sequential(\n",
|
" self.attention_weights = nn.Sequential(\n",
|
||||||
" nn.Linear(embedding_size + hidden_size, max_length),\n",
|
" nn.Linear(embedding_size + hidden_size, max_length),\n",
|
||||||
" nn.Softmax(2)\n",
|
" nn.Softmax(2)\n",
|
||||||
@@ -377,12 +491,14 @@
|
|||||||
" \n",
|
" \n",
|
||||||
" def forward(self, word, h, encoder_outputs):\n",
|
" def forward(self, word, h, encoder_outputs):\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
" :param word: (LongTensor) The word indexes\n",
|
" :param word: (LongTensor) The word indices. This is the last activated word or the start-of-sentence token.\n",
|
||||||
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder\n",
|
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder.\n",
|
||||||
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder\n",
|
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder.\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
|
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
|
||||||
" word_embedding = self.embedding(word).view(1, 1, -1)\n",
|
" word_embedding = self.embedding(word).view(1, 1, -1)\n",
|
||||||
"\n",
|
" \n",
|
||||||
|
" # Concatenate the word embedding and the last hidden state, so that attention weights can be determined.\n",
|
||||||
" x = torch.cat((word_embedding, h), 2)\n",
|
" x = torch.cat((word_embedding, h), 2)\n",
|
||||||
" attention_weights = self.attention_weights(x)\n",
|
" attention_weights = self.attention_weights(x)\n",
|
||||||
" # attention applied\n",
|
" # attention applied\n",
|
||||||
@@ -849,7 +965,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.6.6"
|
"version": "3.7.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
2
seq2seq/img/training_seq2seq_many2may.svg
Normal file
2
seq2seq/img/training_seq2seq_many2may.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 8.4 KiB |
Reference in New Issue
Block a user