This commit is contained in:
Sylvain Gugger
2020-04-15 05:49:39 -07:00
parent 3fdf8c070b
commit 676ebb9941
3 changed files with 6 additions and 18 deletions

View File

@@ -1154,7 +1154,7 @@
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
" \n",
" def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n",
@@ -1362,7 +1362,7 @@
" self.drop = nn.Dropout(p)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h_o.weight = self.i_h.weight\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
" \n",
" def forward(self, x):\n",
" raw,h = self.rnn(self.i_h(x), self.h)\n",