From 01b769d5b68bb34202b76fb3c1e5f4c35ab060a5 Mon Sep 17 00:00:00 2001 From: ritchie46 Date: Mon, 10 Sep 2018 20:55:47 +0200 Subject: [PATCH] attention working --- seq2seq/attention_decoder.ipynb | 857 ++++++++++++++++++++++++++++++++ 1 file changed, 857 insertions(+) create mode 100644 seq2seq/attention_decoder.ipynb diff --git a/seq2seq/attention_decoder.ipynb b/seq2seq/attention_decoder.ipynb new file mode 100644 index 0000000..f4e92d2 --- /dev/null +++ b/seq2seq/attention_decoder.ipynb @@ -0,0 +1,857 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuda\n" + ] + } + ], + "source": [ + "import unicodedata\n", + "import numpy as np\n", + "import torch\n", + "from torch import nn\n", + "import re\n", + "import os\n", + "from tensorboardX import SummaryWriter\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# download the needed data\n", + "if not os.path.isfile('data.zip'):\n", + " ! curl -o data.zip https://download.pytorch.org/tutorial/data.zip && unzip data.zip " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class Lang:\n", + " def __init__(self, name):\n", + " self.name = name\n", + " self.word2count = {}\n", + " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", + " self.word2index = {v:k for k, v in self.index2word.items()}\n", + " self.n_words = 2 # Count SOS and EOS\n", + "\n", + " def add_sentence(self, sentence):\n", + " for word in sentence.split(' '):\n", + " self.add_word(word)\n", + "\n", + " def add_word(self, word):\n", + " if word not in self.word2index:\n", + " self.word2index[word] = self.n_words\n", + " self.word2count[word] = 1\n", + " self.index2word[self.n_words] = word\n", + " self.n_words += 1\n", + " elif word != 'SOS' and word != 'EOS':\n", + " self.word2count[word] += 1\n", + " \n", + " def translate_indexes(self, idx):\n", + " return [self.index2word[i] for i in idx]\n", + " \n", + "# Turn a Unicode string to plain ASCII, thanks to\n", + "# http://stackoverflow.com/a/518232/2809427\n", + "def unicode2ascii(s):\n", + " return ''.join(\n", + " c for c in unicodedata.normalize('NFD', s)\n", + " if unicodedata.category(c) != 'Mn'\n", + " )\n", + "\n", + "# Lowercase, trim, and remove non-letter characters\n", + "def normalize_string(s):\n", + " s = unicode2ascii(s.lower().strip())\n", + " s = re.sub(r\"\\s?[.!?]\", r\" EOS\", s)\n", + " s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n", + " return s\n", + "\n", + "def read_langs(lang1, lang2, reverse=False):\n", + " print(\"Reading lines...\")\n", + "\n", + " # Read the file and split into lines\n", + " lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n", + " read().strip().split('\\n')\n", + " \n", + " # Split every line into pairs and normalize\n", + " pairs = [[normalize_string(s) for s in l.split('\\t')] for l in lines]\n", + " # Reverse pairs, make Lang instances\n", + " if reverse:\n", + " pairs = [list(reversed(p)) for p in pairs]\n", + " input_lang = Lang(lang2)\n", + " output_lang = Lang(lang1)\n", + " else:\n", + " input_lang = Lang(lang1)\n", + " output_lang = Lang(lang2)\n", + "\n", + " return input_lang, output_lang, pairs" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "# Since there are a lot of example sentences and we want to train something quickly, we'll trim the data set to only relatively short and simple sentences. \n", + "# Here the maximum length is 10 words (that includes ending punctuation) and we're filtering to sentences that translate to the form \"I am\" or \"He is\" etc. \n", + "# (accounting for apostrophes replaced earlier).\n", + "\n", + "\n", + "\n", + "def filter_pairs(pairs):\n", + " MAX_LENGTH = 10\n", + " \n", + " eng_prefixes = (\n", + " \"i am \", \"i m \",\n", + " \"he is\", \"he s \",\n", + " \"she is\", \"she s\",\n", + " \"you are\", \"you re \",\n", + " \"we are\", \"we re \",\n", + " \"they are\", \"they re \"\n", + " )\n", + " \n", + " def filter_pair(p):\n", + " return len(p[0].split(' ')) < MAX_LENGTH and \\\n", + " len(p[1].split(' ')) < MAX_LENGTH \\\n", + " and p[0].startswith(eng_prefixes)\n", + " return [pair for pair in pairs if filter_pair(pair)]" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "class Data:\n", + " def __init__(self, pairs, lang_1, lang_2):\n", + " self.pairs = np.array(pairs)\n", + " np.random.seed(9)\n", + " np.random.shuffle(self.pairs)\n", + " idx_1 = [[lang_1.word2index[word] for word in s.split(' ')] \n", + " for s in self.pairs[:, 0]]\n", + " idx_2 = [[lang_2.word2index[word] for word in s.split(' ')]\n", + " for s in self.pairs[:, 1]]\n", + " self.idx_pairs = np.array(list(zip(idx_1, idx_2)))\n", + " \n", + " def __str__(self):\n", + " return(self.pairs)\n", + " \n", + " def shuffle(self):\n", + " np.random.shuffle(self.shuffle_idx)\n", + " self.pairs = self.pairs[self.shuffle_idx]\n", + " self.idx_pairs = self.idx_pairs[self.shuffle_idx] \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading lines...\n", + "Read 135842 sentence pairs\n", + "Trimmed to 10853 sentence pairs\n", + "Counting words...\n", + "Counted words:\n", + "eng 2922\n", + "fra 4486\n" + ] + }, + { + "data": { + "text/plain": [ + "array(['we are even EOS', 'nous sommes a egalite EOS'], dtype='