fixed incorrect method call - torch.stack should be torch.cat, and fixed incorrect variable name (#338)
Co-authored-by: Matus <test@test.com>
This commit is contained in:
parent
b8b6b02aa4
commit
c2fc05e063
@ -1614,14 +1614,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def forward(self, input, state):\n",
|
" def forward(self, input, state):\n",
|
||||||
" h,c = state\n",
|
" h,c = state\n",
|
||||||
" h = torch.stack([h, input], dim=1)\n",
|
" h = torch.cat([h, input], dim=1)\n",
|
||||||
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
||||||
" c = c * forget\n",
|
" c = c * forget\n",
|
||||||
" inp = torch.sigmoid(self.input_gate(h))\n",
|
" inp = torch.sigmoid(self.input_gate(h))\n",
|
||||||
" cell = torch.tanh(self.cell_gate(h))\n",
|
" cell = torch.tanh(self.cell_gate(h))\n",
|
||||||
" c = c + inp * cell\n",
|
" c = c + inp * cell\n",
|
||||||
" out = torch.sigmoid(self.output_gate(h))\n",
|
" out = torch.sigmoid(self.output_gate(h))\n",
|
||||||
" h = outgate * torch.tanh(c)\n",
|
" h = out * torch.tanh(c)\n",
|
||||||
" return h, (h,c)"
|
" return h, (h,c)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -2362,4 +2362,4 @@
|
|||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 2
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user