我是靠谱客的博主 粗心睫毛膏,最近开发中收集的这篇文章主要介绍Pytorch学习第五讲:LSTM网络实现,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

这里主要记录一下lstm网络的pytorch实现:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


class my_lstm(nn.Module):
    def __init__(self):
        super(my_lstm, self).__init__()
        self.conv_f = nn.Sequential(
            nn.Conv2d(32+32, 32, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_i = nn.Sequential(
            nn.Conv2d(32+32, 32, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_g = nn.Sequential(
            nn.Conv2d(32+32, 32, 3, 1, 1),
            nn.Tanh()
        )
        self.conv_o = nn.Sequential(
            nn.Conv2d(32+32, 32, 3, 1, 1),
            nn.Sigmoid
        )

    def forward(self, input):
        batch_size, row, col = input.size(0), input.size(2), input.size(3)
        h = Variable(torch.zeros(batch_size, 32, row, col)).cuda()
        c = Variable(torch.zeros(batch_size, 32, row, col)).cuda()
        x = torch.cat((input, h), 1)
        f = self.conv_f(x)
        i = self.conv_i(x)
        g = self.conv_g(x)
        o = self.conv_o(x)
        c = c * f + i * g
        h = o * F.tanh(c)
        return h

最后

以上就是粗心睫毛膏为你收集整理的Pytorch学习第五讲:LSTM网络实现的全部内容,希望文章能够帮你解决Pytorch学习第五讲:LSTM网络实现所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(47)

评论列表共有 0 条评论

立即
投稿
返回
顶部