attention working

This commit is contained in:
ritchie46
2018-09-10 20:55:47 +02:00
parent b7a14df1c2
commit 01b769d5b6

View File

@@ -0,0 +1,857 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"import unicodedata\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"import re\n",
"import os\n",
"from tensorboardX import SummaryWriter\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# download the needed data\n",
"if not os.path.isfile('data.zip'):\n",
" ! curl -o data.zip https://download.pytorch.org/tutorial/data.zip && unzip data.zip "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.word2index = {v:k for k, v in self.index2word.items()}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def add_sentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.add_word(word)\n",
"\n",
" def add_word(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" elif word != 'SOS' and word != 'EOS':\n",
" self.word2count[word] += 1\n",
" \n",
" def translate_indexes(self, idx):\n",
" return [self.index2word[i] for i in idx]\n",
" \n",
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# http://stackoverflow.com/a/518232/2809427\n",
"def unicode2ascii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"def normalize_string(s):\n",
" s = unicode2ascii(s.lower().strip())\n",
" s = re.sub(r\"\\s?[.!?]\", r\" EOS\", s)\n",
" s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n",
" return s\n",
"\n",
"def read_langs(lang1, lang2, reverse=False):\n",
" print(\"Reading lines...\")\n",
"\n",
" # Read the file and split into lines\n",
" lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n",
" read().strip().split('\\n')\n",
" \n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalize_string(s) for s in l.split('\\t')] for l in lines]\n",
" # Reverse pairs, make Lang instances\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" return input_lang, output_lang, pairs"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"# Since there are a lot of example sentences and we want to train something quickly, we'll trim the data set to only relatively short and simple sentences. \n",
"# Here the maximum length is 10 words (that includes ending punctuation) and we're filtering to sentences that translate to the form \"I am\" or \"He is\" etc. \n",
"# (accounting for apostrophes replaced earlier).\n",
"\n",
"\n",
"\n",
"def filter_pairs(pairs):\n",
" MAX_LENGTH = 10\n",
" \n",
" eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s\",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
" )\n",
" \n",
" def filter_pair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH \\\n",
" and p[0].startswith(eng_prefixes)\n",
" return [pair for pair in pairs if filter_pair(pair)]"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"class Data:\n",
" def __init__(self, pairs, lang_1, lang_2):\n",
" self.pairs = np.array(pairs)\n",
" np.random.seed(9)\n",
" np.random.shuffle(self.pairs)\n",
" idx_1 = [[lang_1.word2index[word] for word in s.split(' ')] \n",
" for s in self.pairs[:, 0]]\n",
" idx_2 = [[lang_2.word2index[word] for word in s.split(' ')]\n",
" for s in self.pairs[:, 1]]\n",
" self.idx_pairs = np.array(list(zip(idx_1, idx_2)))\n",
" \n",
" def __str__(self):\n",
" return(self.pairs)\n",
" \n",
" def shuffle(self):\n",
" np.random.shuffle(self.shuffle_idx)\n",
" self.pairs = self.pairs[self.shuffle_idx]\n",
" self.idx_pairs = self.idx_pairs[self.shuffle_idx] \n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 135842 sentence pairs\n",
"Trimmed to 10853 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"eng 2922\n",
"fra 4486\n"
]
},
{
"data": {
"text/plain": [
"array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='<U60')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def prepare_data(lang1, lang2, reverse=False):\n",
" input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filter_pairs(pairs) \n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" for pair in pairs:\n",
" input_lang.add_sentence(pair[0])\n",
" output_lang.add_sentence(pair[1])\n",
" print(\"Counted words:\")\n",
" print(input_lang.name, input_lang.n_words)\n",
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, Data(pairs, input_lang, output_lang)\n",
"\n",
"\n",
"eng, fra, data = prepare_data('eng', 'fra', False)\n",
"data.pairs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Encoder\n",
"\n",
"The encoder of a seq2seq network is a RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.\n",
"\n",
"![](img/encoder-network.png)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, n_words, embedding_size, hidden_size, bidirectional=True, device=device.type):\n",
" super(Encoder, self).__init__()\n",
" self.bidirectional = bidirectional\n",
" self.hidden_size = hidden_size\n",
" # The word embeddings will also be trained\n",
" # To freeze them --> m.embedding.weight.requires_grad = False\n",
" self.embedding = nn.Embedding(n_words, embedding_size) \n",
" self.rnn = nn.GRU(embedding_size, hidden_size, bidirectional=bidirectional)\n",
" self.device = device\n",
" if device == 'cuda':\n",
" self.cuda()\n",
" \n",
" self.out = nn.Softmax\n",
" \n",
" def forward(self, x):\n",
" # shape (seq_length, batch_size, input_size)\n",
" dense_vector = self.embedding(x).view(x.shape[0], 1, -1)\n",
" \n",
" # init hidden layer at beginning of sequence\n",
" n = 2 if self.bidirectional else 1\n",
" \n",
" h = torch.zeros(n, 1, self.hidden_size, device=self.device)\n",
" \n",
" x, h = self.rnn(dense_vector, h)\n",
"\n",
" return x, h\n",
" \n",
"\n",
"m = Encoder(eng.n_words, 10, 2, True, 'cpu')\n",
"scentence = torch.tensor([400, 1, 2, 6, 8])\n",
"a = m(scentence)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Decoder\n",
"\n",
"In the simplest seq2seq decoder we use only last output of the encoder. This last output is sometimes called the context vector as it encodes context from the entire sequence. This context vector is used as the initial hidden state of the decoder.\n",
"\n",
"At every step of decoding, the decoder is given an input token and hidden state. The initial input token is the start-of-string <SOS> token, and the first hidden state is the context vector (the encoders last hidden state).\n",
" \n",
"![](img/decoder-network.png)"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-23348.0801, grad_fn=<SumBackward0>)"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Decoder(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, output_size, device=device.type):\n",
" super(Decoder, self).__init__()\n",
" self.decoder = 'simple'\n",
" self.hidden_size = hidden_size\n",
" # Lookup table for the last word activation.\n",
" self.embedding = nn.Embedding(output_size, embedding_size)\n",
" self.relu = nn.LeakyReLU()\n",
" self.rnn = nn.GRU(embedding_size, hidden_size)\n",
" self.out = nn.Sequential(\n",
" nn.LeakyReLU(),\n",
" nn.Linear(hidden_size, output_size),\n",
" nn.LogSoftmax(2)\n",
" )\n",
" self.device = device\n",
" if device == 'cuda':\n",
" self.cuda()\n",
" \n",
" def forward(self, word, h):\n",
" word_embedding = self.embedding(word).view(h.shape[0], 1, -1)\n",
" a = self.relu(word_embedding)\n",
" x, h = self.rnn(a, h)\n",
"\n",
" return self.out(x), h\n",
"\n",
"m = Decoder(10, 20, eng.n_words, device='cpu')\n",
"m.train(False)\n",
"m(torch.tensor([1]) ,torch.zeros(1, 1, 20))[0].sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](img/attention-decoder-network.png)"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1, 2])"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class AttentionDecoder(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, output_size, dropout=0.1, max_length=10, device='cpu'):\n",
" super(AttentionDecoder, self).__init__()\n",
" self.decoder = 'attention'\n",
" self.max_length = max_length\n",
" self.device = device\n",
" self.embedding = nn.Sequential(\n",
" nn.Embedding(output_size, embedding_size),\n",
" nn.Dropout(dropout)\n",
" )\n",
" self.attention_weights = nn.Sequential(\n",
" nn.Linear(embedding_size + hidden_size, max_length),\n",
" nn.Softmax(2)\n",
" )\n",
" self.attention_combine = nn.Sequential(\n",
" nn.Linear(hidden_size + embedding_size, hidden_size),\n",
" nn.ReLU()\n",
" )\n",
"\n",
" self.rnn = nn.GRU(hidden_size, hidden_size)\n",
" self.out = nn.Sequential(\n",
" nn.Linear(hidden_size, output_size),\n",
" nn.LogSoftmax(2)\n",
" )\n",
" \n",
" if device == 'cuda':\n",
" self.cuda()\n",
" \n",
" def forward(self, word, h, encoder_outputs):\n",
" \"\"\"\n",
" :param word: (LongTensor) The word indexes\n",
" :param h: (tensor) The hidden state from the previous step. In the first step, the hidden state of the encoder\n",
" :param encoder_outputs: (tensor) Zero padded (max_length, shape, shape) outputs from the encoder\n",
" \"\"\"\n",
" word_embedding = self.embedding(word).view(1, 1, -1)\n",
"\n",
" x = torch.cat((word_embedding, h), 2)\n",
" attention_weights = self.attention_weights(x)\n",
" # attention applied\n",
" x = torch.bmm(attention_weights, encoder_outputs.unsqueeze(0)) # could also be done with matmul\n",
" \n",
" # attention combined\n",
" x = torch.cat((word_embedding, x), 2)\n",
" x = self.attention_combine(x)\n",
" \n",
" x, h = self.rnn(x, h)\n",
" \n",
" x = self.out(x)\n",
"\n",
" return x, h\n",
"\n",
"\n",
"embedding_size = 256\n",
"hidden_size = 256\n",
"max_length = 10\n",
"\n",
"m = Encoder(eng.n_words, embedding_size, hidden_size, bidirectional=False, device='cpu')\n",
"scentence = torch.tensor([1, 23, 9])\n",
"out, h = m(scentence)\n",
"encoder_outputs = torch.zeros(max_length, out.shape[-1], device='cpu')\n",
"encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
"\n",
"\n",
"m = AttentionDecoder(embedding_size, hidden_size, 2, device='cpu')\n",
"m(torch.tensor([1]), h, encoder_outputs)[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"def run_decoder(decoder, scentence, h, teacher_forcing=False, encoder_outputs=None):\n",
" loss = 0\n",
" \n",
" for j in range(scentence.shape[0]):\n",
" if decoder.decoder == 'Attention':\n",
" x, h = decoder(word, h, encoder_outputs)\n",
" else:\n",
" x, h = decoder(word, h)\n",
"\n",
" loss += criterion(x.view(1, -1), scentence[j].view(-1))\n",
" if teacher_forcing:\n",
" word = eng_scentence[j]\n",
" else:\n",
" word = x.argmax().detach()\n",
" if word.item() == 1: # <EOS>\n",
" break\n",
" return loss\n"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"teacher_forcing_ratio = 0.5\n",
"\n",
"embedding_size = 100\n",
"context_vector_size = 256\n",
"bidirectional = False\n",
"encoder = Encoder(eng.n_words, embedding_size, context_vector_size, bidirectional)\n",
"context_vector_size = context_vector_size * 2 if bidirectional else context_vector_size \n",
"decoder = Decoder(embedding_size, context_vector_size, fra.n_words)\n",
"writer = SummaryWriter('tb/emb-100_h256_bidirectionalwRelu')"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def train(encoder, decoder):\n",
" criterion = nn.NLLLoss()\n",
" optim_encoder = torch.optim.SGD(encoder.parameters(), lr=0.01)\n",
" optim_decoder = torch.optim.SGD(decoder.parameters(), lr=0.01)\n",
"\n",
" epochs = 4\n",
" batch_size = 1\n",
"\n",
" encoder.train(True)\n",
" decoder.train(True)\n",
"\n",
" for epoch in range(epochs):\n",
" data.shuffle()\n",
"\n",
" for i in range(data.pairs.shape[0]):\n",
" optim_decoder.zero_grad()\n",
" optim_encoder.zero_grad()\n",
" \n",
" pair = data.idx_pairs[i]\n",
"\n",
" eng_scentence = torch.tensor(pair[0], device=device)\n",
" fra_scentence = torch.tensor(pair[1], device=device)\n",
"\n",
" # Encode the input language\n",
" out, h = encoder(fra_scentence) \n",
" encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
" \n",
" if decoder.decoder == 'attention':\n",
" encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
"\n",
" word = torch.tensor([0], device=device) # <SOS>\n",
" teacher_forcing = np.random.rand() < teacher_forcing_ratio\n",
" loss = run_decoder(decoder, eng_scentence, h, teacher_forcing)\n",
"\n",
" loss.backward()\n",
" writer.add_scalar('loss', loss.cpu().item() / (j + 1))\n",
"\n",
" optim_decoder.step()\n",
" optim_encoder.step()\n",
"\n",
" print(f'epoch {epoch}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def translate(start, end):\n",
" with torch.no_grad():\n",
" for i in range(start, end):\n",
"\n",
" pair = data.idx_pairs[i]\n",
" eng_scentence = torch.tensor(pair[0], device=device)\n",
" fra_scentence = torch.tensor(pair[1], device=device)\n",
"\n",
" print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_scentence.cpu().data.numpy()][:-1]))\n",
" print('Real translation:\\t', ' '.join([fra.index2word[i] for i in fra_scentence.cpu().data.numpy()][:-1]))\n",
"\n",
" h = encoder(eng_scentence)\n",
" word = torch.tensor([0], device=device)\n",
"\n",
" translation = []\n",
" for j in range(fra_scentence.shape[0]):\n",
" x, h = decoder(word, h)\n",
"\n",
" word = x.argmax().detach()\n",
" translation.append(word.cpu().data.tolist())\n",
"\n",
" if word.item() == 1: # <EOS>\n",
" break\n",
" print('Model translation:\\t', ' '.join([fra.index2word[i] for i in translation][:-1]), '\\n')\n",
" \n",
"translate(10, 20)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"English scentence:\t i m pretty busy\n",
"French scentence:\t je suis plutot occupe\n",
"\n",
"Model translation:\t i m pretty busy \n",
"\n",
"\n",
"English scentence:\t she sang better than him\n",
"French scentence:\t elle chanta mieux que lui\n",
"\n",
"Model translation:\t she sang better than him \n",
"\n",
"\n",
"English scentence:\t i m all for that\n",
"French scentence:\t je suis tout a fait pour\n",
"\n",
"Model translation:\t i m all for that \n",
"\n",
"\n",
"English scentence:\t they re right behind me\n",
"French scentence:\t ils se trouvent juste derriere moi\n",
"\n",
"Model translation:\t they re right behind me \n",
"\n",
"\n",
"English scentence:\t you re very funny\n",
"French scentence:\t vous etes fort droles\n",
"\n",
"Model translation:\t you re very funny \n",
"\n",
"\n",
"English scentence:\t you re very rude\n",
"French scentence:\t vous etes fort grossier\n",
"\n",
"Model translation:\t you re very rude \n",
"\n",
"\n",
"English scentence:\t he s the class clown\n",
"French scentence:\t c est le pitre de la classe\n",
"\n",
"Model translation:\t he s the class \n",
"\n",
"\n",
"English scentence:\t they re not always right\n",
"French scentence:\t elles n ont pas toujours raison\n",
"\n",
"Model translation:\t they re not always right \n",
"\n",
"\n",
"English scentence:\t we re all safe\n",
"French scentence:\t nous sommes toutes en securite\n",
"\n",
"Model translation:\t we re all safe \n",
"\n",
"\n",
"English scentence:\t i m smarter than you\n",
"French scentence:\t je suis plus astucieux que vous\n",
"\n",
"Model translation:\t i m smarter than you \n",
"\n",
"\n",
"English scentence:\t i m still your friend\n",
"French scentence:\t je suis toujours votre amie\n",
"\n",
"Model translation:\t i m still your friend \n",
"\n",
"\n",
"English scentence:\t i m asking you for your help\n",
"French scentence:\t je requiers ton aide\n",
"\n",
"Model translation:\t i m asking you \n",
"\n",
"\n",
"English scentence:\t i m used to staying up late\n",
"French scentence:\t je suis habitue a rester debout tard\n",
"\n",
"Model translation:\t i m used to staying up late \n",
"\n",
"\n",
"English scentence:\t you re very helpful\n",
"French scentence:\t vous etes fort serviables\n",
"\n",
"Model translation:\t you re very helpful \n",
"\n",
"\n",
"English scentence:\t you re the reason i came\n",
"French scentence:\t c est pour vous que je suis venu\n",
"\n",
"Model translation:\t you re the reason i came \n",
"\n",
"\n",
"English scentence:\t you re respected by everybody\n",
"French scentence:\t vous etes respecte de tous\n",
"\n",
"Model translation:\t you re respected by everybody \n",
"\n",
"\n",
"English scentence:\t i am singing with my children\n",
"French scentence:\t je chante avec mes enfants\n",
"\n",
"Model translation:\t i am singing with my \n",
"\n",
"\n",
"English scentence:\t i m sorry i don t buy it\n",
"French scentence:\t je suis desolee je ne gobe pas ca\n",
"\n",
"Model translation:\t i m sorry i don t buy it \n",
"\n",
"\n",
"English scentence:\t i m not tired yet\n",
"French scentence:\t je ne suis pas encore fatigue\n",
"\n",
"Model translation:\t i m still tired yet \n",
"\n",
"\n",
"English scentence:\t you re so pathetic\n",
"French scentence:\t tu es si pitoyable\n",
"\n",
"Model translation:\t you re so pathetic \n",
"\n",
"\n",
"English scentence:\t i m not a teacher\n",
"French scentence:\t je ne suis pas instituteur\n",
"\n",
"Model translation:\t i m not a teacher \n",
"\n",
"\n",
"English scentence:\t he is far from happy\n",
"French scentence:\t il n est vraiment pas heureux\n",
"\n",
"Model translation:\t he is far from happy \n",
"\n",
"\n",
"English scentence:\t i m worried about them\n",
"French scentence:\t je me fais du souci pour elles\n",
"\n",
"Model translation:\t i m worried about them \n",
"\n",
"\n",
"English scentence:\t she is not wrong\n",
"French scentence:\t elle n a pas tort\n",
"\n",
"Model translation:\t she s not wrong \n",
"\n",
"\n",
"English scentence:\t i am very dangerous\n",
"French scentence:\t je suis vraiment dangereux\n",
"\n",
"Model translation:\t i am very dangerous \n",
"\n",
"\n",
"English scentence:\t he is washing a car\n",
"French scentence:\t il nettoie une voiture\n",
"\n",
"Model translation:\t he is washing a \n",
"\n",
"\n",
"English scentence:\t we are concerned about our planet\n",
"French scentence:\t nous nous occupons de notre planete\n",
"\n",
"Model translation:\t we are concerned about our planet \n",
"\n",
"\n",
"English scentence:\t she isn t lonely now\n",
"French scentence:\t elle n est plus seule maintenant\n",
"\n",
"Model translation:\t she isn t lonely now \n",
"\n",
"\n",
"English scentence:\t they re all waiting\n",
"French scentence:\t elles attendent toutes\n",
"\n",
"Model translation:\t they re all \n",
"\n",
"\n",
"English scentence:\t she s been sick since last wednesday\n",
"French scentence:\t elle est malade depuis mercredi dernier\n",
"\n",
"Model translation:\t she s been sick since last \n",
"\n",
"\n",
"English scentence:\t we are listening to the radio\n",
"French scentence:\t nous sommes en train d ecouter la radio\n",
"\n",
"Model translation:\t we are listening to the radio \n",
"\n",
"\n",
"English scentence:\t i m sorry to have bothered you\n",
"French scentence:\t je suis desolee de t avoir derange\n",
"\n",
"Model translation:\t i m sorry to have bothered you \n",
"\n",
"\n",
"English scentence:\t i m not at all afraid\n",
"French scentence:\t je n ai pas peur du tout\n",
"\n",
"Model translation:\t i m not at at all \n",
"\n",
"\n",
"English scentence:\t he is pleased with his new shoes\n",
"French scentence:\t il est content de ses nouvelles chaussures\n",
"\n",
"Model translation:\t he is pleased with his new shoes \n",
"\n",
"\n",
"English scentence:\t she is a very intelligent young lady\n",
"French scentence:\t c est une jeune femme tres intelligente\n",
"\n",
"Model translation:\t she is a very intelligent young lady \n",
"\n",
"\n",
"English scentence:\t we are sorry for the inconvenience\n",
"French scentence:\t nous sommes desoles pour le derangement\n",
"\n",
"Model translation:\t we are sorry for the interruption \n",
"\n",
"\n",
"English scentence:\t i am not studying now\n",
"French scentence:\t en ce moment je n etudie pas\n",
"\n",
"Model translation:\t i am not studying now \n",
"\n",
"\n",
"English scentence:\t you re not alone\n",
"French scentence:\t vous n etes pas seule\n",
"\n",
"Model translation:\t you re not alone \n",
"\n",
"\n",
"English scentence:\t he is clumsy with his hands\n",
"French scentence:\t il est maladroit de ses mains\n",
"\n",
"Model translation:\t he is clumsy with his hands \n",
"\n",
"\n",
"English scentence:\t i am too tired to run\n",
"French scentence:\t je suis trop fatigue pour courir\n",
"\n",
"Model translation:\t i am too tired to run \n",
"\n",
"\n"
]
}
],
"source": [
"def translate( start, end):\n",
" \n",
" for i in range(start, end):\n",
" \n",
" pair = data.idx_pairs[i]\n",
" eng_scentence = torch.tensor(pair[0], device=device)\n",
" fra_scentence = torch.tensor(pair[1], device=device)\n",
"\n",
" print('English scentence:\\t', ' '.join([eng.index2word[i] for i in eng_scentence.cpu().data.numpy()][:-1]))\n",
" print('French scentence:\\t', ' '.join([fra.index2word[i] for i in fra_scentence.cpu().data.numpy()][:-1]))\n",
"\n",
" # Encode the input language\n",
" out, h = encoder(fra_scentence) \n",
" encoder_outputs = torch.zeros(max_length, out.shape[-1], device=device)\n",
" encoder_outputs[:out.shape[0], :out.shape[-1]] = out.view(out.shape[0], -1)\n",
" \n",
" word = torch.tensor([0], device=device) # <SOS>\n",
" \n",
" translation = []\n",
" for j in range(fra_scentence.shape[0]):\n",
" x, h = decoder(word, h, encoder_outputs=encoder_outputs)\n",
" \n",
" word = x.argmax().detach()\n",
" translation.append(word.cpu().data.tolist())\n",
"\n",
" if word.item() == 1: # <EOS>\n",
" break\n",
" print('\\nModel translation:\\t', ' '.join([eng.index2word[i] for i in translation][:-1]), '\\n\\n')\n",
" \n",
"translate(20, 60)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}