我是靠谱客的博主 聪明大叔,最近开发中收集的这篇文章主要介绍PyTorch笔记,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

目录

1.  nn.Sequential

2. torch.flatten(input, start_dim=0, end_dim=-1)

 3.model.train()和model.eval()

4.optimizer优化器

5.Pytorch 提供了很多不同的参数初始化函数

6.json.dumps() json.loads()

7.pytorch中model eval和torch no grad()的区别

8.load_state_dict 

9.python 中参数*args, **kwargs


 

1.  nn.Sequential

顺序容器。模块将按照在构造函数中传递的顺序添加到它。或者,也可以传入模块的有序字典。

# Example of using Sequential
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

2. torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,也就是要被展平 tensor。

start_dim: 起始维度。

end_dim: 结束维度。

t = torch.tensor([[[1, 2, 2, 1],
                   [3, 4, 4, 3],
                   [1, 2, 3, 4]],
                  [[5, 6, 6, 5],
                   [7, 8, 8, 7],
                   [5, 6, 7, 8]]])
print(t, t.shape)
运行后:
tensor([[[1, 2, 2, 1],
         [3, 4, 4, 3],
         [1, 2, 3, 4]],

        [[5, 6, 6, 5],
         [7, 8, 8, 7],
         [5, 6, 7, 8]]])
torch.Size([2, 3, 4])

x = torch.flatten(t, start_dim=1) # 第1维到最后一维展平
print(x, x.shape)

y = torch.flatten(t, start_dim=0, end_dim=1) # 第0维到第1维合并
print(y, y.shape)
运行后:
tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4],
        [5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]]) torch.Size([2, 12])
tensor([[1, 2, 2, 1],
        [3, 4, 4, 3],
        [1, 2, 3, 4],
        [5, 6, 6, 5],
        [7, 8, 8, 7],
        [5, 6, 7, 8]]) torch.Size([6, 4])

参考自https://dilthey.cnblogs.com/

 3.model.train()和model.eval()

分别在训练和测试中都要写,它们的作用如下:

(1). model.train()
启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
(2). model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False

总结

(1)在训练模块中千万不要忘了写model.train()

(2)在评估(或测试)模块千万不要忘了写model.eval()

(3)在没有涉及到BN与Dropout的网络,这两个函数存在对于网络性能的影响,好坏不定

4.optimizer优化器

[pytorch]几种optimizer优化器的使用

pytorch中的Optimizer的灵活运用 

5.Pytorch 提供了很多不同的参数初始化函数

  • torch.nn.init.constant_(tensor,val)
  • torch.nn.init.normal_(tensor,mean=0,std=1)
  • torch.nn.init.xavier_uniform_(tensor,gain=1)

 注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

参考:博客地址

6.json.dumps() json.loads()

json.dumps()用于将字典形式的数据转化为字符串,json.loads()用于将字符串形式的数据转化为字典

import json

data = {
    'name' : 'Connor',
    'sex' : 'boy',
    'age' : 26
}
print(data)
data1=json.dumps(data)
print(data1)
data2=json.loads(data1)
print(data2)
print(type(data))#输出原始数据格式
print(type(data1))#输出经过json.dumps的数据格式
print(type(data2))#输出经过json.loads的数据格式

#对应输出结果
{'name': 'Connor', 'sex': 'boy', 'age': 26}
{"name": "Connor", "sex": "boy", "age": 26}
{'name': 'Connor', 'sex': 'boy', 'age': 26}
<class 'dict'>
<class 'str'>
<class 'dict'>

如果直接将dict类型的数据写入json文件中会发生报错,因此在将数据写入时需要用到json.dump(),

json.load()用于从json文件中读取数据

with open('data3.json','a',encoding='utf-8') as f: 
    f.write(data1)
    f.close()
data4=json.load(open('data3.json'))#json.load()用于读取json数据
print(data4)
#打印结果
{'name': 'Connor', 'sex': 'boy', 'age': 26}

参考:https://www.cnblogs.com/ConnorShip/p/9744223.html

7.pytorch中model eval和torch no grad()的区别

在PyTorch中进行validation时,会使用model.eval()切换到测试模式,在该模式下,

  • 主要用于通知dropout层和batchnorm层在train和val模式间切换
    • train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。
    • val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
  • 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传(backprobagation)
  • with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。
  • 参考:博客地址

8.load_state_dict 

(1). load(self)

这个函数会递归地对模型进行参数恢复,其中的_load_from_state_dict的源码附在文末。

首先我们需要明确state_dict这个变量表示你之前保存的模型参数序列,而_load_from_state_dict函数中的local_state表示你的代码中定义的模型的结构。

那么_load_from_state_dict的作用简单理解就是假如我们现在需要对一个名为conv.weight的子模块做参数恢复,那么就以递归的方式先判断conv是否在state__dictlocal_state中,如果不在就把conv添加到unexpected_keys中去,否则递归的判断conv.weight是否存在,如果都存在就执行param.copy_(input_param),这样就完成了conv.weight的参数拷贝。

(2).if strict:

这个部分的作用是判断上面参数拷贝过程中是否有unexpected_keys或者missing_keys,如果有就报错,代码不能继续执行。当然,如果strict=False,则会忽略这些细节。

def load_state_dict(self, state_dict, strict=True):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)

    if strict:
        error_msg = ''
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
                           self.__class__.__name__, "nt".join(error_msgs)))

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

参考:https://zhuanlan.zhihu.com/p/98563721

           https://blog.csdn.net/strive_for_future/article/details/83240081 

9.python 中参数*args, **kwargs

def foo(* args, ** kwargs):
print ' args = ',  args
print ' kwargs = ',  kwargs
print '---------------------------------------'

if __name__ == '__main__':
foo(1,2,3,4)
foo(a=1,b=2,c=3)
foo(1,2,3,4, a=1,b=2,c=3)
foo('a', 1, None, a=1, b='2', c=3)

输出结果如下:

args =  (1, 2, 3, 4) 
kwargs =  {} 

args =  () 
kwargs =  {'a': 1, 'c': 3, 'b': 2} 

args =  (1, 2, 3, 4) 
kwargs =  {'a': 1, 'c': 3, 'b': 2} 

args =  ('a', 1, None) 
kwargs =  {'a': 1, 'c': 3, 'b': '2'}

可以看到,这两个是python中的可变参数。

*args表示任何多个无名参数,它是一个tuple;

**kwargs表示关键字参数,它是一个 dict。**kwargs 允许你将不定长度的键值对, 作为参数传递给一个函数。 如果你想要在一个函数里处理带名字的参数, 你应该使用**kwargs。

并且同时使用*args和**kwargs时,必须*args参数列要在**kwargs前,像foo(a=1, b='2', c=3, a', 1, None, 这样调用的话,会提示语法错误“SyntaxError: non-keyword arg after keyword arg”。

 

最后

以上就是聪明大叔为你收集整理的PyTorch笔记的全部内容,希望文章能够帮你解决PyTorch笔记所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(58)

评论列表共有 0 条评论

立即
投稿
返回
顶部