more explanations in seq2seq
This commit is contained in:
@@ -30,7 +30,39 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
|
||||
" Dload Upload Total Spent Left Speed\n",
|
||||
"100 2814k 100 2814k 0 0 507k 0 0:00:05 0:00:05 --:--:-- 631k\n",
|
||||
"Archive: data.zip\n",
|
||||
" creating: data/\n",
|
||||
" inflating: data/eng-fra.txt \n",
|
||||
" creating: data/names/\n",
|
||||
" inflating: data/names/Arabic.txt \n",
|
||||
" inflating: data/names/Chinese.txt \n",
|
||||
" inflating: data/names/Czech.txt \n",
|
||||
" inflating: data/names/Dutch.txt \n",
|
||||
" inflating: data/names/English.txt \n",
|
||||
" inflating: data/names/French.txt \n",
|
||||
" inflating: data/names/German.txt \n",
|
||||
" inflating: data/names/Greek.txt \n",
|
||||
" inflating: data/names/Irish.txt \n",
|
||||
" inflating: data/names/Italian.txt \n",
|
||||
" inflating: data/names/Japanese.txt \n",
|
||||
" inflating: data/names/Korean.txt \n",
|
||||
" inflating: data/names/Polish.txt \n",
|
||||
" inflating: data/names/Portuguese.txt \n",
|
||||
" inflating: data/names/Russian.txt \n",
|
||||
" inflating: data/names/Scottish.txt \n",
|
||||
" inflating: data/names/Spanish.txt \n",
|
||||
" inflating: data/names/Vietnamese.txt \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# download the needed data\n",
|
||||
"if not os.path.isfile('data.zip'):\n",
|
||||
@@ -39,24 +71,68 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" de question !\n",
|
||||
"Really?\tVraiment ?\n",
|
||||
"Really?\tVrai ?\n",
|
||||
"Really?\tAh bon ?\n",
|
||||
"Thanks.\tMerci !\n",
|
||||
"We try.\tOn essaye.\n",
|
||||
"We won.\tNous avons gagné.\n",
|
||||
"We won.\tNous gagnâmes.\n",
|
||||
"We won.\tNous l'avons emporté.\n",
|
||||
"We won.\tNous l'empor\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Take a quick view of the data.\n",
|
||||
"with open('data/eng-fra.txt') as f:\n",
|
||||
" f.seek(1000)\n",
|
||||
" print(f.read(200))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"class Lang:\n",
|
||||
" \"\"\"\n",
|
||||
" Utility class that serves as a language dictionary\n",
|
||||
" \"\"\"\n",
|
||||
" def __init__(self, name):\n",
|
||||
" self.name = name\n",
|
||||
" # Count how often a word occurs in the language data.\n",
|
||||
" self.word2count = {}\n",
|
||||
" # Words are mapped to indices and vice versa\n",
|
||||
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
|
||||
" self.word2index = {v:k for k, v in self.index2word.items()}\n",
|
||||
" # Total word count\n",
|
||||
" self.n_words = 2 # Count SOS and EOS\n",
|
||||
"\n",
|
||||
" def add_sentence(self, sentence):\n",
|
||||
" \"\"\"\n",
|
||||
" Process words in a sentence string.\n",
|
||||
" \n",
|
||||
" :param sentence: (str) \n",
|
||||
" \"\"\"\n",
|
||||
" for word in sentence.split(' '):\n",
|
||||
" self.add_word(word)\n",
|
||||
"\n",
|
||||
" def add_word(self, word):\n",
|
||||
" \"\"\"\n",
|
||||
" Process words\n",
|
||||
" :param word: (str)\n",
|
||||
" \"\"\"\n",
|
||||
" if word not in self.word2index:\n",
|
||||
" self.word2index[word] = self.n_words\n",
|
||||
" self.word2count[word] = 1\n",
|
||||
@@ -66,6 +142,9 @@
|
||||
" self.word2count[word] += 1\n",
|
||||
" \n",
|
||||
" def translate_indexes(self, idx):\n",
|
||||
" \"\"\"\n",
|
||||
" Takes in a vector of indices and returns the sentence.\n",
|
||||
" \"\"\"\n",
|
||||
" return [self.index2word[i] for i in idx]\n",
|
||||
" \n",
|
||||
"# Turn a Unicode string to plain ASCII, thanks to\n",
|
||||
@@ -106,7 +185,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 90,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -137,7 +216,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 91,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -165,7 +244,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -187,18 +266,23 @@
|
||||
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def prepare_data(lang1, lang2, reverse=False):\n",
|
||||
" # read_langs initializes the Lang objects (still empty) and returns the sentence pairs.\n",
|
||||
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
|
||||
" print(\"Read %s sentence pairs\" % len(pairs))\n",
|
||||
" \n",
|
||||
" # Reduce data. We haven't got all day to train a model.\n",
|
||||
" pairs = filter_pairs(pairs) \n",
|
||||
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
|
||||
" print(\"Counting words...\")\n",
|
||||
" \n",
|
||||
" # Process the language pairs.\n",
|
||||
" for pair in pairs:\n",
|
||||
" input_lang.add_sentence(pair[0])\n",
|
||||
" output_lang.add_sentence(pair[1])\n",
|
||||
@@ -220,17 +304,32 @@
|
||||
"\n",
|
||||
"The encoder of a seq2seq network is an RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
|
||||
"\n",
|
||||
""
|
||||
"\n",
|
||||
"\n",
|
||||
"Every output could be seen as the context of the sentence up to that point.\n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 96,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([5, 1, 2])"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class Encoder(nn.Module):\n",
|
||||
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=True, device=device.type):\n",
|
||||
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
|
||||
" super(Encoder, self).__init__()\n",
|
||||
" self.bidirectional = bidirectional\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
@@ -242,8 +341,6 @@
|
||||
" if device == 'cuda':\n",
|
||||
" self.cuda()\n",
|
||||
" \n",
|
||||
" self.out = nn.Softmax\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" # shape (seq_length, batch_size, input_size)\n",
|
||||
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
|
||||
@@ -258,9 +355,10 @@
|
||||
" return x, h\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"m = Encoder(eng.n_words, 10, 2, True, 'cpu')\n",
|
||||
"m = Encoder(eng.n_words, 10, 2, False, 'cpu')\n",
|
||||
"scentence = torch.tensor([400, 1, 2, 6, 8])\n",
|
||||
"a = m(scentence)\n"
|
||||
"a = m(scentence)\n",
|
||||
"a[0].shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -273,21 +371,30 @@
|
||||
"\n",
|
||||
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoder’s last hidden state).\n",
|
||||
" \n",
|
||||
""
|
||||
"\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 123,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor(-23348.0801, grad_fn=<SumBackward0>)"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 10])\n",
|
||||
"torch.Size([1, 1, 10])\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 123,
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor(-23351.8633, grad_fn=<SumBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -312,6 +419,11 @@
|
||||
" self.cuda()\n",
|
||||
" \n",
|
||||
" def forward(self, word, h):\n",
|
||||
" \"\"\"\n",
|
||||
" :param word: (tensor) Last word or start of sentence token.\n",
|
||||
" :param h: (tensor) Hidden state or context tensor.\n",
|
||||
" \"\"\"\n",
|
||||
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
|
||||
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
|
||||
" a = self.relu(word_embedding)\n",
|
||||
" x, h = self.rnn(a, h)\n",
|
||||
@@ -357,6 +469,8 @@
|
||||
" nn.Embedding(output_size, embedding_size),\n",
|
||||
" nn.Dropout(dropout)\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Separate neural network to learn the attention weights\n",
|
||||
" self.attention_weights = nn.Sequential(\n",
|
||||
" nn.Linear(embedding_size + hidden_size, max_length),\n",
|
||||
" nn.Softmax(2)\n",
|
||||
@@ -377,12 +491,14 @@
|
||||
" \n",
|
||||
" def forward(self, word, h, encoder_outputs):\n",
|
||||
" \"\"\"\n",
|
||||
" :param word: (LongTensor) The word indexes\n",
|
||||
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder\n",
|
||||
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder\n",
|
||||
" :param word: (LongTensor) The word indices. This is the last activated word or the start-of-sentence token.\n",
|
||||
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder.\n",
|
||||
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder.\n",
|
||||
" \"\"\"\n",
|
||||
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
|
||||
" word_embedding = self.embedding(word).view(1, 1, -1)\n",
|
||||
" \n",
|
||||
" # Concatenate the word embedding and the last hidden state, so that attention weights can be determined.\n",
|
||||
" x = torch.cat((word_embedding, h), 2)\n",
|
||||
" attention_weights = self.attention_weights(x)\n",
|
||||
" # attention applied\n",
|
||||
@@ -849,7 +965,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.6"
|
||||
"version": "3.7.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
2
seq2seq/img/training_seq2seq_many2may.svg
Normal file
2
seq2seq/img/training_seq2seq_many2may.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 8.4 KiB |
Reference in New Issue
Block a user