概述
Pytorch框架torchvision库学习
torchvision是pytorch的一个图形库,它包含了torchvision.datasets、torchvision.models、torchvision.transforms、torchvision.utils四部分。
1、torchvision.datasets: 一些数据集。
2、torchvision.models: 常见卷积网络模型。
3、torchvision.transforms: 数据预处理、图片变换等操作。详细介绍转:http://t.csdn.cn/CmWNj
4、torchvision.utils: 其他函数。
文章目录
- Pytorch框架torchvision库学习
- 前言
- 1.torchvision.datasets
- 2.DataLoader
- 3.torchvision.models
- 导入模型
- 改模型默认下载目录
- 修改模型
- 保存模型
- 加载模型
- 4.torchvision.utils
- 总结
前言
最近在学习pytorch,总结一下。
1.torchvision.datasets
datasets这个包有很多数据集,比如MINIST、COCO、CIFAR10 and CIFAR100、LSUN 、Classification、ImageFolder、Imagenet-12、STL10。torchvision.datasets中的数据集封装都是torch.utils.data.Dataset子类,它们都实现了__getitem__ 和 __len__方法,都可以用DataLoader进行数据加载。
torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False )
参数介绍:
root:数据集的根目录
train:如果为True,训练集,否则是测试集
download:如果为true,根目录没有数据集就会自动在这个目录下载。
transform:数据集预处理,比如归一化当图形转换类的操作
target_transform:接收目标并对其进行转换的函数/转换。
MNIST数据集示例
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
# 数据预处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081))])
# 训练集
train_dataset = datasets.MNIST(root='../data/mnist',train=True, download=True, transform=transform)
# 测试集
test_dataset=datasets.MNIST(root='../data/mnist',train=False, download=True, transform=transform)
# 数据集加载器
# train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
# test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)
2.DataLoader
DataLoader数据加载器,PyTorch数据读取重要接口,用PyTorch架构来训练模型基本都会用到该接口,把数据分块送入模型进行训练。
from torch.utils.data import DataLoader
train_loader= DataLoader(dataset = train_dataset, # 数据加载
batch_size = 4, # 送入多少张图片
shuffle = True, #对原有数据排序是否打乱
num_workers = 0, #是否进行多进程加载数据设置
drop_last = False) #最后的数据组不成一个batch_size 是否丢弃
参数:
dataset:数据加载
batch_size :送入多少张图片
shuffle :是否打乱数据
sampler :指定数据加载中使用的索引/键的序列
batch_sampler = None,#和sampler类似
num_workers :是否进行多进程加载数据设置
collate_fn = None,#是否合并样本列表以形成一小批Tensor
pin_memory :数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last :最后的数据组不成一个batch_size 是否丢弃
3.torchvision.models
models包含以下模型:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNetV2
MobileNetV3
ResNeXt
Wide ResNet
MNASNet
EfficientNet
RegNet
导入模型
import torchvision.models as models
#alexnet = models.alexnet(pretrained=True) # 加载预训练权重
alexnet = models.alexnet() # AlexNet 不加载
vgg16 = models.vgg16() # VGG16
resnet18 = models.resnet18() # ResetNet模型
print(vgg16 ) # 打印模型
改模型默认下载目录
import os
os.environ['TORCH_HOME']='E:/Data/torch-model'
修改模型
import torchvision.models as models
vgg16 = models.vgg16() # VGG16
# 在classifier层添加add_linear
vgg16.classifier.add_module("add_linear",Linear(1000,10))
# 在classifier层修改add_linear参数
vgg16_false.classifier[6]=Linear(4096,10)
保存模型
path = "D:/code/text/model1.pth"
#torch.save(model,path)
torch.save(model.state_dict(),path) # 保存模型
加载模型
解决pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable
# 错误原因:之前保存网络时用的方法是torch.save(model, ‘Nei.pkl’),这样保存下来的Net.pkl是一个状态字典,而不是模型本身,也就是说Net.pkl中保存的只是网络的参数,而没有网络结构。
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load('D:/code/text/model1.pth')) # 导入网络的参数
4.torchvision.utils
拼接图片
组成图像的网络,将多张图片组合成一张图片
torchvision.utils.make_grid(images)
参数:
tensor:4D张量,形状为(B x C x H x W),图像列表
nrow:每行的图片数量,默认值为8
padding:相邻图像之间的间隔。默认值为2
normalize:如果为True,则把图像的像素值通过range指定的最大值和最小值归一化到0-1。默认为False
range:元组,用于指定最大值和最小值。默认使用图像像素的最大最小值。
sacle_each:如果为True,就单独对每张图像进行normalize;如果是False,统一对所有图像进行normalize。默认为Flase
pad_value:float,上述padding会使得图像之间留出空隙,默认为0
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(root='../data/mnist',train=True,transform=data_tf,download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=False)
images, labels = next(iter(train_loader)) # batch_size 高 长 宽
# 组成图像的网络,将多张图片组合成一张图片
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
def cv_show(name, img): # 长宽高
cv2.imshow(name, img)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv_show('image', img)
保存图片
torchvision.utils.save_image(img, imgPath)
总结
未完待续,,,
最后
以上就是忧郁指甲油为你收集整理的torchvision学习(2)——datasets、models(加载数据、调用模型)Pytorch框架torchvision库学习前言总结的全部内容,希望文章能够帮你解决torchvision学习(2)——datasets、models(加载数据、调用模型)Pytorch框架torchvision库学习前言总结所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复