Fix shapes
This commit is contained in:
parent
c854e380a6
commit
fb7e2b8af4
@ -1723,7 +1723,7 @@
|
||||
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
||||
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
||||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||||
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
||||
@ -2039,7 +2039,7 @@
|
||||
" self.drop = nn.Dropout(p)\n",
|
||||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||||
" self.h_o.weight = self.i_h.weight\n",
|
||||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||||
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" raw,h = self.rnn(self.i_h(x), self.h)\n",
|
||||
|
@ -1154,7 +1154,7 @@
|
||||
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
||||
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
||||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||||
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
||||
@ -1362,7 +1362,7 @@
|
||||
" self.drop = nn.Dropout(p)\n",
|
||||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||||
" self.h_o.weight = self.i_h.weight\n",
|
||||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||||
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" raw,h = self.rnn(self.i_h(x), self.h)\n",
|
||||
|
Loading…
Reference in New Issue
Block a user