fixed incorrect method call - torch.stack should be torch.cat, and fixed incorrect variable name (#338)
Co-authored-by: Matus <test@test.com>
This commit is contained in:
parent
b8b6b02aa4
commit
c2fc05e063
@ -1614,14 +1614,14 @@
|
||||
"\n",
|
||||
" def forward(self, input, state):\n",
|
||||
" h,c = state\n",
|
||||
" h = torch.stack([h, input], dim=1)\n",
|
||||
" h = torch.cat([h, input], dim=1)\n",
|
||||
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
||||
" c = c * forget\n",
|
||||
" inp = torch.sigmoid(self.input_gate(h))\n",
|
||||
" cell = torch.tanh(self.cell_gate(h))\n",
|
||||
" c = c + inp * cell\n",
|
||||
" out = torch.sigmoid(self.output_gate(h))\n",
|
||||
" h = outgate * torch.tanh(c)\n",
|
||||
" h = out * torch.tanh(c)\n",
|
||||
" return h, (h,c)"
|
||||
]
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user