我是靠谱客的博主 忧郁指甲油,最近开发中收集的这篇文章主要介绍torchvision学习(2)——datasets、models(加载数据、调用模型)Pytorch框架torchvision库学习前言总结,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Pytorch框架torchvision库学习


torchvision是pytorch的一个图形库,它包含了torchvision.datasets、torchvision.models、torchvision.transforms、torchvision.utils四部分。
1、torchvision.datasets: 一些数据集。
2、torchvision.models: 常见卷积网络模型。
3、torchvision.transforms: 数据预处理、图片变换等操作。详细介绍转:http://t.csdn.cn/CmWNj
4、torchvision.utils: 其他函数。

文章目录

  • Pytorch框架torchvision库学习
  • 前言
    • 1.torchvision.datasets
    • 2.DataLoader
    • 3.torchvision.models
      • 导入模型
      • 改模型默认下载目录
      • 修改模型
      • 保存模型
      • 加载模型
    • 4.torchvision.utils
  • 总结


前言

最近在学习pytorch,总结一下。


1.torchvision.datasets

datasets这个包有很多数据集,比如MINIST、COCO、CIFAR10 and CIFAR100、LSUN 、Classification、ImageFolder、Imagenet-12、STL10。torchvision.datasets中的数据集封装都是torch.utils.data.Dataset子类,它们都实现了__getitem__ 和 __len__方法,都可以用DataLoader进行数据加载。

  torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False

参数介绍:
root:数据集的根目录
train:如果为True,训练集,否则是测试集
download:如果为true,根目录没有数据集就会自动在这个目录下载。
transform:数据集预处理,比如归一化当图形转换类的操作
target_transform:接收目标并对其进行转换的函数/转换。

MNIST数据集示例

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
# 数据预处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081))])
# 训练集
train_dataset = datasets.MNIST(root='../data/mnist',train=True, download=True, transform=transform)
# 测试集
test_dataset=datasets.MNIST(root='../data/mnist',train=False, download=True, transform=transform)
# 数据集加载器 
# train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
# test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

2.DataLoader

DataLoader数据加载器,PyTorch数据读取重要接口,用PyTorch架构来训练模型基本都会用到该接口,把数据分块送入模型进行训练。

from torch.utils.data import DataLoader

train_loader= DataLoader(dataset = train_dataset,  # 数据加载
                        batch_size = 4,    # 送入多少张图片
                        shuffle = True,    #对原有数据排序是否打乱
                        num_workers = 0,   #是否进行多进程加载数据设置
                        drop_last = False) #最后的数据组不成一个batch_size 是否丢弃

参数:
dataset:数据加载
batch_size :送入多少张图片
shuffle :是否打乱数据
sampler :指定数据加载中使用的索引/键的序列
batch_sampler = None,#和sampler类似
num_workers :是否进行多进程加载数据设置
collat​​e_fn = None,#是否合并样本列表以形成一小批Tensor
pin_memory :数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last :最后的数据组不成一个batch_size 是否丢弃

3.torchvision.models

models包含以下模型:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNetV2
MobileNetV3
ResNeXt
Wide ResNet
MNASNet
EfficientNet
RegNet

导入模型

import torchvision.models as models
 
#alexnet = models.alexnet(pretrained=True)  # 加载预训练权重
alexnet = models.alexnet()   # AlexNet   不加载
vgg16 = models.vgg16()       # VGG16
resnet18 = models.resnet18() # ResetNet模型
print(vgg16 )   # 打印模型

改模型默认下载目录

import os
os.environ['TORCH_HOME']='E:/Data/torch-model'

修改模型

import torchvision.models as models
vgg16 = models.vgg16()       # VGG16
# 在classifier层添加add_linear
vgg16.classifier.add_module("add_linear",Linear(1000,10))
# 在classifier层修改add_linear参数
vgg16_false.classifier[6]=Linear(4096,10)

保存模型

path = "D:/code/text/model1.pth"
    #torch.save(model,path)
    torch.save(model.state_dict(),path)   # 保存模型

加载模型

解决pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable
# 错误原因:之前保存网络时用的方法是torch.save(model, ‘Nei.pkl’),这样保存下来的Net.pkl是一个状态字典,而不是模型本身,也就是说Net.pkl中保存的只是网络的参数,而没有网络结构。

    model = torchvision.models.vgg16(pretrained=False)
    model.load_state_dict(torch.load('D:/code/text/model1.pth')) # 导入网络的参数

4.torchvision.utils

拼接图片
组成图像的网络,将多张图片组合成一张图片

torchvision.utils.make_grid(images)

参数:
tensor:4D张量,形状为(B x C x H x W),图像列表
nrow:每行的图片数量,默认值为8
padding:相邻图像之间的间隔。默认值为2
normalize:如果为True,则把图像的像素值通过range指定的最大值和最小值归一化到0-1。默认为False
range:元组,用于指定最大值和最小值。默认使用图像像素的最大最小值。
sacle_each:如果为True,就单独对每张图像进行normalize;如果是False,统一对所有图像进行normalize。默认为Flase
pad_value:float,上述padding会使得图像之间留出空隙,默认为0

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

train_dataset = datasets.MNIST(root='../data/mnist',train=True,transform=data_tf,download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=False)

images, labels = next(iter(train_loader))   # batch_size 高 长 宽
# 组成图像的网络,将多张图片组合成一张图片
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)

def cv_show(name, img):  # 长宽高
    cv2.imshow(name, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

cv_show('image', img)

在这里插入图片描述

保存图片

torchvision.utils.save_image(img, imgPath)

总结

未完待续,,,

最后

以上就是忧郁指甲油为你收集整理的torchvision学习(2)——datasets、models(加载数据、调用模型)Pytorch框架torchvision库学习前言总结的全部内容,希望文章能够帮你解决torchvision学习(2)——datasets、models(加载数据、调用模型)Pytorch框架torchvision库学习前言总结所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(53)

评论列表共有 0 条评论

立即
投稿
返回
顶部