Fix the initial LSTM state in both copies of the notebook.
nn.LSTM takes its state as a pair (h_0, c_0), each shaped (n_layers, bs, n_hidden),
so self.h must hold exactly two tensors regardless of n_layers.
NOTE(review): the blob ids on the "index" lines were recorded before this correction and may be stale.
NOTE(review): whitespace inside the quoted notebook lines was reconstructed - confirm against the notebooks.

diff --git a/12_nlp_dive.ipynb b/12_nlp_dive.ipynb
index 011adf9..0e1b988 100644
--- a/12_nlp_dive.ipynb
+++ b/12_nlp_dive.ipynb
@@ -1723,7 +1723,7 @@
     "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
     "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
     "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
-    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
+    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
     "        \n",
     "    def forward(self, x):\n",
     "        res,h = self.rnn(self.i_h(x), self.h)\n",
@@ -2039,7 +2039,7 @@
     "        self.drop = nn.Dropout(p)\n",
     "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
     "        self.h_o.weight = self.i_h.weight\n",
-    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
+    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
     "        \n",
     "    def forward(self, x):\n",
     "        raw,h = self.rnn(self.i_h(x), self.h)\n",
diff --git a/clean/12_nlp_dive.ipynb b/clean/12_nlp_dive.ipynb
index 1a9c231..9f50802 100644
--- a/clean/12_nlp_dive.ipynb
+++ b/clean/12_nlp_dive.ipynb
@@ -1154,7 +1154,7 @@
     "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
     "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
     "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
-    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
+    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
     "        \n",
     "    def forward(self, x):\n",
     "        res,h = self.rnn(self.i_h(x), self.h)\n",
@@ -1362,7 +1362,7 @@
     "        self.drop = nn.Dropout(p)\n",
     "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
     "        self.h_o.weight = self.i_h.weight\n",
-    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
+    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
     "        \n",
     "    def forward(self, x):\n",
     "        raw,h = self.rnn(self.i_h(x), self.h)\n",