Fix shapes

This commit is contained in:
Sylvain Gugger 2020-04-13 07:46:14 -07:00
parent c854e380a6
commit fb7e2b8af4
2 changed files with 4 additions and 4 deletions

View File

@ -1723,7 +1723,7 @@
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n",
" def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n",
@ -2039,7 +2039,7 @@
" self.drop = nn.Dropout(p)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h_o.weight = self.i_h.weight\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n",
" def forward(self, x):\n",
" raw,h = self.rnn(self.i_h(x), self.h)\n",

View File

@ -1154,7 +1154,7 @@
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n",
" def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n",
@ -1362,7 +1362,7 @@
" self.drop = nn.Dropout(p)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h_o.weight = self.i_h.weight\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n",
" def forward(self, x):\n",
" raw,h = self.rnn(self.i_h(x), self.h)\n",