概述
目录
1. nn.Sequential
2. torch.flatten(input, start_dim=0, end_dim=-1)
3.model.train()和model.eval()
4.optimizer优化器
5.Pytorch 提供了很多不同的参数初始化函数
6.json.dumps() json.loads()
7.pytorch中model eval和torch no grad()的区别
8.load_state_dict
9.python 中参数*args, **kwargs
1. nn.Sequential
顺序容器。模块将按照在构造函数中传递的顺序添加到它。或者,也可以传入模块的有序字典。
# Example of using Sequential
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
2. torch.flatten(input, start_dim=0, end_dim=-1)
input: 一个 tensor,也就是要被展平 tensor。
start_dim: 起始维度。
end_dim: 结束维度。
t = torch.tensor([[[1, 2, 2, 1], [3, 4, 4, 3], [1, 2, 3, 4]], [[5, 6, 6, 5], [7, 8, 8, 7], [5, 6, 7, 8]]]) print(t, t.shape) 运行后: tensor([[[1, 2, 2, 1], [3, 4, 4, 3], [1, 2, 3, 4]], [[5, 6, 6, 5], [7, 8, 8, 7], [5, 6, 7, 8]]]) torch.Size([2, 3, 4]) x = torch.flatten(t, start_dim=1) # 第1维到最后一维展平 print(x, x.shape) y = torch.flatten(t, start_dim=0, end_dim=1) # 第0维到第1维合并 print(y, y.shape) 运行后: tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4], [5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]]) torch.Size([2, 12]) tensor([[1, 2, 2, 1], [3, 4, 4, 3], [1, 2, 3, 4], [5, 6, 6, 5], [7, 8, 8, 7], [5, 6, 7, 8]]) torch.Size([6, 4])
参考自https://dilthey.cnblogs.com/
3.model.train()和model.eval()
分别在训练和测试中都要写,它们的作用如下:
(1). model.train()
启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
(2). model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False总结
(1)在训练模块中千万不要忘了写model.train()
(2)在评估(或测试)模块千万不要忘了写model.eval()
(3)在没有涉及到BN与Dropout的网络,这两个函数存在对于网络性能的影响,好坏不定
4.optimizer优化器
[pytorch]几种optimizer优化器的使用
pytorch中的Optimizer的灵活运用
5.Pytorch 提供了很多不同的参数初始化函数
- torch.nn.init.constant_(tensor,val)
- torch.nn.init.normal_(tensor,mean=0,std=1)
- torch.nn.init.xavier_uniform_(tensor,gain=1)
注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。
参考:博客地址
6.json.dumps() json.loads()
json.dumps()用于将字典形式的数据转化为字符串,json.loads()用于将字符串形式的数据转化为字典
import json data = { 'name' : 'Connor', 'sex' : 'boy', 'age' : 26 } print(data) data1=json.dumps(data) print(data1) data2=json.loads(data1) print(data2) print(type(data))#输出原始数据格式 print(type(data1))#输出经过json.dumps的数据格式 print(type(data2))#输出经过json.loads的数据格式 #对应输出结果 {'name': 'Connor', 'sex': 'boy', 'age': 26} {"name": "Connor", "sex": "boy", "age": 26} {'name': 'Connor', 'sex': 'boy', 'age': 26} <class 'dict'> <class 'str'> <class 'dict'>
如果直接将dict类型的数据写入json文件中会发生报错,因此在将数据写入时需要用到json.dump(),
json.load()用于从json文件中读取数据
with open('data3.json','a',encoding='utf-8') as f: f.write(data1) f.close() data4=json.load(open('data3.json'))#json.load()用于读取json数据 print(data4) #打印结果 {'name': 'Connor', 'sex': 'boy', 'age': 26}
参考:https://www.cnblogs.com/ConnorShip/p/9744223.html
7.pytorch中model eval和torch no grad()的区别
在PyTorch中进行validation时,会使用
model.eval()
切换到测试模式,在该模式下,
- 主要用于通知
dropout
层和batchnorm
层在train和val模式间切换
- 在
train
模式下,dropout
网络层会按照设定的参数p
设置保留激活单元的概率(保留概率=p);batchnorm
层会继续计算数据的mean和var等参数并更新。- 在
val
模式下,dropout
层会让所有的激活单元都通过,而batchnorm
层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。- 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传(backprobagation)
- 而
with torch.no_grad()
则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。- 参考:博客地址
8.load_state_dict
(1).
load(self)
这个函数会递归地对模型进行参数恢复,其中的
_load_from_state_dict
的源码附在文末。首先我们需要明确
state_dict
这个变量表示你之前保存的模型参数序列,而_load_from_state_dict
函数中的local_state
表示你的代码中定义的模型的结构。那么
_load_from_state_dict
的作用简单理解就是假如我们现在需要对一个名为conv.weight
的子模块做参数恢复,那么就以递归的方式先判断conv
是否在state__dict
和local_state
中,如果不在就把conv
添加到unexpected_keys
中去,否则递归的判断conv.weight
是否存在,如果都存在就执行param.copy_(input_param)
,这样就完成了conv.weight
的参数拷贝。(2).
if strict:
这个部分的作用是判断上面参数拷贝过程中是否有
unexpected_keys
或者missing_keys
,如果有就报错,代码不能继续执行。当然,如果strict=False
,则会忽略这些细节。def load_state_dict(self, state_dict, strict=True): missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(self) if strict: error_msg = '' if len(unexpected_keys) > 0: error_msgs.insert( 0, 'Unexpected key(s) in state_dict: {}. '.format( ', '.join('"{}"'.format(k) for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, 'Missing key(s) in state_dict: {}. '.format( ', '.join('"{}"'.format(k) for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format( self.__class__.__name__, "nt".join(error_msgs)))
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
参考:https://zhuanlan.zhihu.com/p/98563721
https://blog.csdn.net/strive_for_future/article/details/83240081
9.python 中参数*args, **kwargs
def foo(* args, ** kwargs): print ' args = ', args print ' kwargs = ', kwargs print '---------------------------------------' if __name__ == '__main__': foo(1,2,3,4) foo(a=1,b=2,c=3) foo(1,2,3,4, a=1,b=2,c=3) foo('a', 1, None, a=1, b='2', c=3)
输出结果如下:
args = (1, 2, 3, 4) kwargs = {} args = () kwargs = {'a': 1, 'c': 3, 'b': 2} args = (1, 2, 3, 4) kwargs = {'a': 1, 'c': 3, 'b': 2} args = ('a', 1, None) kwargs = {'a': 1, 'c': 3, 'b': '2'}
可以看到,这两个是python中的可变参数。
*args表示任何多个无名参数,它是一个tuple;
**kwargs表示关键字参数,它是一个 dict。
**kwargs
允许你将不定长度的键值对, 作为参数传递给一个函数。 如果你想要在一个函数里处理带名字的参数, 你应该使用**kwargs。
并且同时使用*args和**kwargs时,必须*args参数列要在**kwargs前,像foo(a=1, b='2', c=3, a', 1, None, 这样调用的话,会提示语法错误“SyntaxError: non-keyword arg after keyword arg”。
最后
以上就是聪明大叔为你收集整理的PyTorch笔记的全部内容,希望文章能够帮你解决PyTorch笔记所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复