概述
分享人工智能技术干货,专注深度学习与计算机视觉领域!
相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模型参数情况,也可以借助第三方库来获取模型参数情况,下面,就让我们一起来了解Pytorch获取模型参数情况的这两种方法!
Pytorch依据其内建接口自己写代码获取模型参数情况,我们主要是借助该框架提供的模型parameters()接口并获取对应参数的size来实现的,对于该参数是否属于可训练参数,那么我们可以依据Pytorch提供的requires_grad标志位来进行判断,具体方法如下代码所示:
# 定义总参数量、可训练参数量及非可训练参数量变量
Total_params = 0
Trainable_params = 0
NonTrainable_params = 0
# 遍历model.parameters()返回的全局参数列表
for param in model.parameters():
mulValue = np.prod(param.size()) # 使用numpy prod接口计算参数数组所有元素之积
Total_params += mulValue # 总参数量
if param.requires_grad:
Trainable_params += mulValue # 可训练参数量
else:
NonTrainable_params += mulValue # 非可训练参数量
print(f'Total p
最后
以上就是粗暴面包为你收集整理的pytorch 获取模型参数_Pytorch获取模型参数情况的方法的全部内容,希望文章能够帮你解决pytorch 获取模型参数_Pytorch获取模型参数情况的方法所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复