fixed incorrect method call - torch.stack should be torch.cat, and fixed incorrect variable name (#338)

Co-authored-by: Matus <test@test.com>
This commit is contained in:
Matus-Dubrava 2020-11-29 15:06:13 +01:00 committed by GitHub
parent b8b6b02aa4
commit c2fc05e063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1614,14 +1614,14 @@
"\n",
" def forward(self, input, state):\n",
" h,c = state\n",
" h = torch.stack([h, input], dim=1)\n",
" h = torch.cat([h, input], dim=1)\n",
" forget = torch.sigmoid(self.forget_gate(h))\n",
" c = c * forget\n",
" inp = torch.sigmoid(self.input_gate(h))\n",
" cell = torch.tanh(self.cell_gate(h))\n",
" c = c + inp * cell\n",
" out = torch.sigmoid(self.output_gate(h))\n",
" h = outgate * torch.tanh(c)\n",
" h = out * torch.tanh(c)\n",
" return h, (h,c)"
]
},