Overview
'''
# nn.RNN and nn.LSTM have nearly identical constructor signatures.
# When calling an LSTM you may optionally pass the initial states (h_0, c_0);
# if they are omitted, they default to zeros.
rnn = nn.LSTM(10, 20, 2)                 # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)               # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
'''
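For reference, here is a self-contained, runnable version of that snippet with the imports included and the resulting shapes printed; it follows PyTorch's default of batch_first=False.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)              # (seq_len=5, batch=3, input_size=10)
h0 = torch.randn(2, 3, 20)             # (num_layers=2, batch=3, hidden_size=20)
c0 = torch.randn(2, 3, 20)

output, (hn, cn) = lstm(x, (h0, c0))   # the (h0, c0) argument is optional
print(output.shape)                    # torch.Size([5, 3, 20]): top-layer h at every step
print(hn.shape, cn.shape)              # torch.Size([2, 3, 20]) each: final states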
# A worked example: using an LSTM for text classification.
class Classfication_Model(nn.Module):
    def __init__(self):
        super(Classfication_Model, self).__init__()
        self.hidden_size = 128
        self.embedding_dim = 200
        self.number_layer = 4
        self.bidirectional = True
        self.bi_number = 2 if self.bidirectional else 1
        self.dropout = 0.5
        # Vocabulary size comes from a pretrained word-vector model (external
        # `model`); +200 leaves headroom for extra tokens such as padding/UNK.
        self.embedding = nn.Embedding(num_embeddings=len(model.index_to_key) + 200,
                                      embedding_dim=self.embedding_dim)
        self.lstm = nn.LSTM(input_size=self.embedding_dim,
                            hidden_size=self.hidden_size,
                            num_layers=self.number_layer,
                            dropout=self.dropout,
                            bidirectional=self.bidirectional)
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * self.bi_number, 20),
            nn.ReLU(),
            nn.Linear(20, 2)
        )

    def init_hidden_state(self, batch_size):
        # States are built batch-first here so DataParallel can split them along
        # dim 0; forward() permutes them back to (num_layers*bi, batch, hidden).
        h_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        c_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        return (h_0, c_0)

    def forward(self, input, hidden):
        input_embeded = self.embedding(input)
        input_embeded = input_embeded.permute(1, 0, 2)  # -> (seq_len, batch_size, embedding_dim)
        hidden = [x.permute(1, 0, 2).contiguous() for x in hidden]
        _, (h_n, c_n) = self.lstm(input_embeded, hidden)
        # Concatenate the last layer's forward and backward final hidden states:
        # shape (batch_size, hidden_size * 2) = (batch_size, 256)
        out = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), -1)
        out = self.fc(out)
        return out
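The training loop below accesses classfication_model.module, which implies the model was wrapped in nn.DataParallel; it also relies on globals such as batch, device, optimizer, and criterion that the original post does not show. A minimal setup sketch under those assumptions (the batch size and optimizer choice are guesses, and corpus_dataset / train_set / test_set are assumed to be defined elsewhere):

from torch.utils.data import DataLoader

# Hypothetical setup inferred from the loop below; not shown in the original.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = 128  # assumed batch size

classfication_model = Classfication_Model().to(device)
classfication_model = nn.DataParallel(classfication_model)  # enables .module access

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classfication_model.parameters(), lr=1e-3)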
def train(epoch):
    ds = corpus_dataset(train_model=True, max_sentence_length=50,
                        train_set=train_set, test_set=test_set)
    train_dataloader = DataLoader(ds, batch, shuffle=True, num_workers=5)
    total_loss = 0
    classfication_model.train()
    # hidden = classfication_model.init_hidden_state(batch)         # fails under DataParallel
    # hidden = classfication_model.module.init_hidden_state(batch)  # hard-codes the batch size
    for idx, (input, target) in enumerate(train_dataloader):
        target = target.to(device)
        input = input.to(device)
        optimizer.zero_grad()
        # Build fresh h_0/c_0 for every batch; len(input) tracks the actual
        # batch size, so the last (possibly smaller) batch also works.
        hidden = classfication_model.module.init_hidden_state(len(input))
        output = classfication_model(input, hidden)
        loss = criterion(output, target)  # targets must be 0-indexed: [0, 9], not [1, 10]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch:{epoch} ###### total_loss:{total_loss:.6f}")
Wrapping up
That covers these study notes on PyTorch LSTM/RNN code; hopefully they help with similar development problems.