more explanations in seq2seq

This commit is contained in:
ritchie46
2018-12-05 14:58:06 +01:00
parent 980182c5c7
commit 7d891445bd
2 changed files with 142 additions and 24 deletions

View File

@@ -30,7 +30,39 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 2814k 100 2814k 0 0 507k 0 0:00:05 0:00:05 --:--:-- 631k\n",
"Archive: data.zip\n",
" creating: data/\n",
" inflating: data/eng-fra.txt \n",
" creating: data/names/\n",
" inflating: data/names/Arabic.txt \n",
" inflating: data/names/Chinese.txt \n",
" inflating: data/names/Czech.txt \n",
" inflating: data/names/Dutch.txt \n",
" inflating: data/names/English.txt \n",
" inflating: data/names/French.txt \n",
" inflating: data/names/German.txt \n",
" inflating: data/names/Greek.txt \n",
" inflating: data/names/Irish.txt \n",
" inflating: data/names/Italian.txt \n",
" inflating: data/names/Japanese.txt \n",
" inflating: data/names/Korean.txt \n",
" inflating: data/names/Polish.txt \n",
" inflating: data/names/Portuguese.txt \n",
" inflating: data/names/Russian.txt \n",
" inflating: data/names/Scottish.txt \n",
" inflating: data/names/Spanish.txt \n",
" inflating: data/names/Vietnamese.txt \n"
]
}
],
"source": [
"# download the needed data\n",
"if not os.path.isfile('data.zip'):\n",
@@ -39,24 +71,68 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" de question !\n",
"Really?\tVraiment?\n",
"Really?\tVrai ?\n",
"Really?\tAh bon ?\n",
"Thanks.\tMerci !\n",
"We try.\tOn essaye.\n",
"We won.\tNous avons gagné.\n",
"We won.\tNous gagnâmes.\n",
"We won.\tNous l'avons emporté.\n",
"We won.\tNous l'empor\n"
]
}
],
"source": [
"# Take a quick view of the data.\n",
"with open('data/eng-fra.txt') as f:\n",
" f.seek(1000)\n",
" print(f.read(200))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class Lang:\n",
" \"\"\"\n",
" Utility class that serves as a language dictionary\n",
" \"\"\"\n",
" def __init__(self, name):\n",
" self.name = name\n",
" # Count how often a word occurs in the language data.\n",
" self.word2count = {}\n",
" # Words are mapped to indices and vice versa\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.word2index = {v:k for k, v in self.index2word.items()}\n",
" # Total word count\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def add_sentence(self, sentence):\n",
" \"\"\"\n",
" Process words in a sentence string.\n",
" \n",
" :param sentence: (str) \n",
" \"\"\"\n",
" for word in sentence.split(' '):\n",
" self.add_word(word)\n",
"\n",
" def add_word(self, word):\n",
" \"\"\"\n",
" Process words\n",
" :param word: (str)\n",
" \"\"\"\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
@@ -66,6 +142,9 @@
" self.word2count[word] += 1\n",
" \n",
" def translate_indexes(self, idx):\n",
" \"\"\"\n",
" Takes in a vector of indices and returns the sentence.\n",
" \"\"\"\n",
" return [self.index2word[i] for i in idx]\n",
" \n",
"# Turn a Unicode string to plain ASCII, thanks to\n",
@@ -106,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -137,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -165,7 +244,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -187,18 +266,23 @@
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
]
},
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def prepare_data(lang1, lang2, reverse=False):\n",
" # read_langs initializes the Lang objects (still empty) and returns the pair sentences.\n",
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" \n",
" # Reduce data. We haven't got all day to train a model.\n",
" pairs = filter_pairs(pairs) \n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" \n",
" # Process the language pairs.\n",
" for pair in pairs:\n",
" input_lang.add_sentence(pair[0])\n",
" output_lang.add_sentence(pair[1])\n",
@@ -220,17 +304,32 @@
"\n",
"The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
"\n",
"![](img/encoder-network.png)"
"![](img/encoder-network.png)\n",
"\n",
"Every output could be seen as the context of the sentence up to that point.\n",
"\n",
"![](img/training_seq2seq_many2may.svg)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 20,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 1, 2])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=True, device=device.type):\n",
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
" super(Encoder, self).__init__()\n",
" self.bidirectional = bidirectional\n",
" self.hidden_size = hidden_size\n",
@@ -242,8 +341,6 @@
" if device == 'cuda':\n",
" self.cuda()\n",
" \n",
" self.out = nn.Softmax\n",
" \n",
" def forward(self, x):\n",
" # shape (seq_length, batch_size, input_size)\n",
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
@@ -258,9 +355,10 @@
" return x, h\n",
" \n",
"\n",
"m = Encoder(eng.n_words, 10, 2, True, 'cpu')\n",
"m = Encoder(eng.n_words, 10, 2, False, 'cpu')\n",
"scentence = torch.tensor([400, 1, 2, 6, 8])\n",
"a = m(scentence)\n"
"a = m(scentence)\n",
"a[0].shape"
]
},
{
@@ -273,21 +371,30 @@
"\n",
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoder's last hidden state).\n",
" \n",
"![](img/decoder-network.png)"
"![](img/decoder-network.png)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-23348.0801, grad_fn=<SumBackward0>)"
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 10])\n",
"torch.Size([1, 1, 10])\n"
]
},
"execution_count": 123,
{
"data": {
"text/plain": [
"tensor(-23351.8633, grad_fn=<SumBackward0>)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -312,6 +419,11 @@
" self.cuda()\n",
" \n",
" def forward(self, word, h):\n",
" \"\"\"\n",
" :param word: (tensor) Last word or start of sentence token.\n",
" :param h: (tensor) Hidden state or context tensor.\n",
" \"\"\"\n",
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
" a = self.relu(word_embedding)\n",
" x, h = self.rnn(a, h)\n",
@@ -357,6 +469,8 @@
" nn.Embedding(output_size, embedding_size),\n",
" nn.Dropout(dropout)\n",
" )\n",
" \n",
" # Separate neural network to learn the attention weights\n",
" self.attention_weights = nn.Sequential(\n",
" nn.Linear(embedding_size + hidden_size, max_length),\n",
" nn.Softmax(2)\n",
@@ -377,12 +491,14 @@
" \n",
" def forward(self, word, h, encoder_outputs):\n",
" \"\"\"\n",
" :param word: (LongTensor) The word indexes\n",
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder\n",
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder\n",
" :param word: (LongTensor) The word indices. This is the last activated word or the start-of-sentence token.\n",
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder.\n",
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder.\n",
" \"\"\"\n",
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Note: seq length is the number of words in the sentence)\n",
" word_embedding = self.embedding(word).view(1, 1, -1)\n",
" \n",
" # Concatenate the word embedding and the last hidden state, so that attention weights can be determined.\n",
" x = torch.cat((word_embedding, h), 2)\n",
" attention_weights = self.attention_weights(x)\n",
" # attention applied\n",
@@ -849,7 +965,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.7.1"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 8.4 KiB