fixed incorrect method call - torch.stack should be torch.cat, and fixed incorrect variable name (#338)
Co-authored-by: Matus <test@test.com>
This commit is contained in:
@@ -1614,14 +1614,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def forward(self, input, state):\n",
|
" def forward(self, input, state):\n",
|
||||||
" h,c = state\n",
|
" h,c = state\n",
|
||||||
" h = torch.stack([h, input], dim=1)\n",
|
" h = torch.cat([h, input], dim=1)\n",
|
||||||
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
||||||
" c = c * forget\n",
|
" c = c * forget\n",
|
||||||
" inp = torch.sigmoid(self.input_gate(h))\n",
|
" inp = torch.sigmoid(self.input_gate(h))\n",
|
||||||
" cell = torch.tanh(self.cell_gate(h))\n",
|
" cell = torch.tanh(self.cell_gate(h))\n",
|
||||||
" c = c + inp * cell\n",
|
" c = c + inp * cell\n",
|
||||||
" out = torch.sigmoid(self.output_gate(h))\n",
|
" out = torch.sigmoid(self.output_gate(h))\n",
|
||||||
" h = outgate * torch.tanh(c)\n",
|
" h = out * torch.tanh(c)\n",
|
||||||
" return h, (h,c)"
|
" return h, (h,c)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user