这里主要记录一下lstm网络的pytorch实现:
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38import torch import torch.nn as nn from torch.autograd import Variable import torch.nn.functional as F class my_lstm(nn.Module): def __init__(self): super(my_lstm, self).__init__() self.conv_f = nn.Sequential( nn.Conv2d(32+32, 32, 3, 1, 1), nn.Sigmoid() ) self.conv_i = nn.Sequential( nn.Conv2d(32+32, 32, 3, 1, 1), nn.Sigmoid() ) self.conv_g = nn.Sequential( nn.Conv2d(32+32, 32, 3, 1, 1), nn.Tanh() ) self.conv_o = nn.Sequential( nn.Conv2d(32+32, 32, 3, 1, 1), nn.Sigmoid ) def forward(self, input): batch_size, row, col = input.size(0), input.size(2), input.size(3) h = Variable(torch.zeros(batch_size, 32, row, col)).cuda() c = Variable(torch.zeros(batch_size, 32, row, col)).cuda() x = torch.cat((input, h), 1) f = self.conv_f(x) i = self.conv_i(x) g = self.conv_g(x) o = self.conv_o(x) c = c * f + i * g h = o * F.tanh(c) return h
最后
以上就是粗心睫毛膏最近收集整理的关于Pytorch学习第五讲:LSTM网络实现的全部内容,更多相关Pytorch学习第五讲内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复