复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31在这里插入代码import torch import torch.utils.data as Data#批处理模块 torch.manual_seed(1) # reproducible BATCH_SIZE = 5#每小批5个 # BATCH_SIZE = 8 x = torch.linspace(1, 10, 10) # this is x data (torch tensor) y = torch.linspace(10, 1, 10) # this is y data (torch tensor) torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader(#使训练分批 dataset=torch_dataset, # torch TensorDataset format batch_size=BATCH_SIZE, # mini batch size shuffle=True, # 训练时随机打乱数据再抽样 num_workers=2, # subprocesses for loading data ) def show_batch(): for epoch in range(3): # 训练整个数据集三次(每次都是拆分成小组训练) for step, (batch_x, batch_y) in enumerate(loader): # for each training step # train your data... print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy()) if __name__ == '__main__': show_batch()片
视频传送门
结果图:
最后
以上就是疯狂面包最近收集整理的关于pytorch:分批训练的全部内容,更多相关pytorch内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复