more explanations in seq2seq

This commit is contained in:
ritchie46
2018-12-05 14:58:06 +01:00
parent 980182c5c7
commit 7d891445bd
2 changed files with 142 additions and 24 deletions

View File

@@ -30,7 +30,39 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 2814k 100 2814k 0 0 507k 0 0:00:05 0:00:05 --:--:-- 631k\n",
"Archive: data.zip\n",
" creating: data/\n",
" inflating: data/eng-fra.txt \n",
" creating: data/names/\n",
" inflating: data/names/Arabic.txt \n",
" inflating: data/names/Chinese.txt \n",
" inflating: data/names/Czech.txt \n",
" inflating: data/names/Dutch.txt \n",
" inflating: data/names/English.txt \n",
" inflating: data/names/French.txt \n",
" inflating: data/names/German.txt \n",
" inflating: data/names/Greek.txt \n",
" inflating: data/names/Irish.txt \n",
" inflating: data/names/Italian.txt \n",
" inflating: data/names/Japanese.txt \n",
" inflating: data/names/Korean.txt \n",
" inflating: data/names/Polish.txt \n",
" inflating: data/names/Portuguese.txt \n",
" inflating: data/names/Russian.txt \n",
" inflating: data/names/Scottish.txt \n",
" inflating: data/names/Spanish.txt \n",
" inflating: data/names/Vietnamese.txt \n"
]
}
],
"source": [ "source": [
"# download the needed data\n", "# download the needed data\n",
"if not os.path.isfile('data.zip'):\n", "if not os.path.isfile('data.zip'):\n",
@@ -39,24 +71,68 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" de question !\n",
"Really?\tVraiment?\n",
"Really?\tVrai ?\n",
"Really?\tAh bon ?\n",
"Thanks.\tMerci !\n",
"We try.\tOn essaye.\n",
"We won.\tNous avons gagné.\n",
"We won.\tNous gagnâmes.\n",
"We won.\tNous l'avons emporté.\n",
"We won.\tNous l'empor\n"
]
}
],
"source": [
"# Take a quick view of the data.\n",
"with open('data/eng-fra.txt') as f:\n",
" f.seek(1000)\n",
" print(f.read(200))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"\n", "\n",
"class Lang:\n", "class Lang:\n",
" \"\"\"\n",
" Utility class that serves as a language dictionary\n",
" \"\"\"\n",
" def __init__(self, name):\n", " def __init__(self, name):\n",
" self.name = name\n", " self.name = name\n",
" # Count how often a word occurs in the language data.\n",
" self.word2count = {}\n", " self.word2count = {}\n",
" # Words are mapped to indices and vice versa\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.word2index = {v:k for k, v in self.index2word.items()}\n", " self.word2index = {v:k for k, v in self.index2word.items()}\n",
" # Total word count\n",
" self.n_words = 2 # Count SOS and EOS\n", " self.n_words = 2 # Count SOS and EOS\n",
"\n", "\n",
" def add_sentence(self, sentence):\n", " def add_sentence(self, sentence):\n",
" \"\"\"\n",
" Process words in a sentence string.\n",
" \n",
" :param sentence: (str) \n",
" \"\"\"\n",
" for word in sentence.split(' '):\n", " for word in sentence.split(' '):\n",
" self.add_word(word)\n", " self.add_word(word)\n",
"\n", "\n",
" def add_word(self, word):\n", " def add_word(self, word):\n",
" \"\"\"\n",
" Process words\n",
" :param word: (str)\n",
" \"\"\"\n",
" if word not in self.word2index:\n", " if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n", " self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n", " self.word2count[word] = 1\n",
@@ -66,6 +142,9 @@
" self.word2count[word] += 1\n", " self.word2count[word] += 1\n",
" \n", " \n",
" def translate_indexes(self, idx):\n", " def translate_indexes(self, idx):\n",
" \"\"\"\n",
" Takes in a vector of indices and returns the sentence.\n",
" \"\"\"\n",
" return [self.index2word[i] for i in idx]\n", " return [self.index2word[i] for i in idx]\n",
" \n", " \n",
"# Turn a Unicode string to plain ASCII, thanks to\n", "# Turn a Unicode string to plain ASCII, thanks to\n",
@@ -106,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 90, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -137,7 +216,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 91, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -165,7 +244,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -187,18 +266,23 @@
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')" "array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
] ]
}, },
"execution_count": 5, "execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"def prepare_data(lang1, lang2, reverse=False):\n", "def prepare_data(lang1, lang2, reverse=False):\n",
" # read_langs initialized the Lang objects (still empty) and returns the pair sentences.\n",
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n", " input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n", " print(\"Read %s sentence pairs\" % len(pairs))\n",
" \n",
" # Reduce data. We haven't got all day to train a model.\n",
" pairs = filter_pairs(pairs) \n", " pairs = filter_pairs(pairs) \n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n", " print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n", " print(\"Counting words...\")\n",
" \n",
" # Process the language pairs.\n",
" for pair in pairs:\n", " for pair in pairs:\n",
" input_lang.add_sentence(pair[0])\n", " input_lang.add_sentence(pair[0])\n",
" output_lang.add_sentence(pair[1])\n", " output_lang.add_sentence(pair[1])\n",
@@ -220,17 +304,32 @@
"\n", "\n",
"The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n", "The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
"\n", "\n",
"![](img/encoder-network.png)" "![](img/encoder-network.png)\n",
"\n",
"Every output could be seen as the context of the sentence up to that point.\n",
"\n",
"![](img/training_seq2seq_many2may.svg)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 96, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 1, 2])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"class Encoder(nn.Module):\n", "class Encoder(nn.Module):\n",
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=True, device=device.type):\n", " def __init__(self, n_words, embedding_size, hidden_size, bidirectional=False, device=device.type):\n",
" super(Encoder, self).__init__()\n", " super(Encoder, self).__init__()\n",
" self.bidirectional = bidirectional\n", " self.bidirectional = bidirectional\n",
" self.hidden_size = hidden_size\n", " self.hidden_size = hidden_size\n",
@@ -241,9 +340,7 @@
" self.device = device\n", " self.device = device\n",
" if device == 'cuda':\n", " if device == 'cuda':\n",
" self.cuda()\n", " self.cuda()\n",
" \n", " \n",
" self.out = nn.Softmax\n",
" \n",
" def forward(self, x):\n", " def forward(self, x):\n",
" # shape (seq_length, batch_size, input_size)\n", " # shape (seq_length, batch_size, input_size)\n",
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n", " dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
@@ -258,9 +355,10 @@
" return x, h\n", " return x, h\n",
" \n", " \n",
"\n", "\n",
"m = Encoder(eng.n_words, 10, 2, True, 'cpu')\n", "m = Encoder(eng.n_words, 10, 2, False, 'cpu')\n",
"scentence = torch.tensor([400, 1, 2, 6, 8])\n", "scentence = torch.tensor([400, 1, 2, 6, 8])\n",
"a = m(scentence)\n" "a = m(scentence)\n",
"a[0].shape"
] ]
}, },
{ {
@@ -273,21 +371,30 @@
"\n", "\n",
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoders last hidden state).\n", "At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoders last hidden state).\n",
" \n", " \n",
"![](img/decoder-network.png)" "![](img/decoder-network.png)\n",
" "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 123, "execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 10])\n",
"torch.Size([1, 1, 10])\n"
]
},
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor(-23348.0801, grad_fn=<SumBackward0>)" "tensor(-23351.8633, grad_fn=<SumBackward0>)"
] ]
}, },
"execution_count": 123, "execution_count": 24,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -312,6 +419,11 @@
" self.cuda()\n", " self.cuda()\n",
" \n", " \n",
" def forward(self, word, h):\n", " def forward(self, word, h):\n",
" \"\"\"\n",
" :param word: (tensor) Last word or start of sentence token.\n",
" :param h: (tensor) Hidden state or context tensor.\n",
" \"\"\"\n",
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Notel: seq length is the number of words in the sentence)\n",
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n", " word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
" a = self.relu(word_embedding)\n", " a = self.relu(word_embedding)\n",
" x, h = self.rnn(a, h)\n", " x, h = self.rnn(a, h)\n",
@@ -357,6 +469,8 @@
" nn.Embedding(output_size, embedding_size),\n", " nn.Embedding(output_size, embedding_size),\n",
" nn.Dropout(dropout)\n", " nn.Dropout(dropout)\n",
" )\n", " )\n",
" \n",
" # Seperate neural network to learn the attention weights\n",
" self.attention_weights = nn.Sequential(\n", " self.attention_weights = nn.Sequential(\n",
" nn.Linear(embedding_size + hidden_size, max_length),\n", " nn.Linear(embedding_size + hidden_size, max_length),\n",
" nn.Softmax(2)\n", " nn.Softmax(2)\n",
@@ -377,12 +491,14 @@
" \n", " \n",
" def forward(self, word, h, encoder_outputs):\n", " def forward(self, word, h, encoder_outputs):\n",
" \"\"\"\n", " \"\"\"\n",
" :param word: (LongTensor) The word indexes\n", " :param word: (LongTensor) The word indices. This is the last activated word or \n",
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder\n", " :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder.\n",
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder\n", " :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder.\n",
" \"\"\"\n", " \"\"\"\n",
" # map from shape (seq_len, embedding_size) to (seq_len, batch, embedding_size) (Notel: seq length is the number of words in the sentence)\n",
" word_embedding = self.embedding(word).view(1, 1, -1)\n", " word_embedding = self.embedding(word).view(1, 1, -1)\n",
"\n", " \n",
" # Concatenate the word embedding and the last hidden state, so that attention weights can be determined.\n",
" x = torch.cat((word_embedding, h), 2)\n", " x = torch.cat((word_embedding, h), 2)\n",
" attention_weights = self.attention_weights(x)\n", " attention_weights = self.attention_weights(x)\n",
" # attention applied\n", " # attention applied\n",
@@ -849,7 +965,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.6" "version": "3.7.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 8.4 KiB