我是靠谱客的博主 粗暴面包,最近开发中收集的这篇文章主要介绍pytorch 获取模型参数_Pytorch获取模型参数情况的方法,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

分享人工智能技术干货,专注深度学习与计算机视觉领域!

相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模型参数情况,也可以借助第三方库来获取模型参数情况,下面,就让我们一起来了解Pytorch获取模型参数情况的这两种方法!

Pytorch依据其内建接口自己写代码获取模型参数情况,我们主要是借助该框架提供的模型parameters()接口并获取对应参数的size来实现的,对于该参数是否属于可训练参数,那么我们可以依据Pytorch提供的requires_grad标志位来进行判断,具体方法如下代码所示:

# 定义总参数量、可训练参数量及非可训练参数量变量

Total_params = 0

Trainable_params = 0

NonTrainable_params = 0

# 遍历model.parameters()返回的全局参数列表

for param in model.parameters():

mulValue = np.prod(param.size()) # 使用numpy prod接口计算参数数组所有元素之积

Total_params += mulValue # 总参数量

if param.requires_grad:

Trainable_params += mulValue # 可训练参数量

else:

NonTrainable_params += mulValue # 非可训练参数量

print(f'Total p

最后

以上就是粗暴面包为你收集整理的pytorch 获取模型参数_Pytorch获取模型参数情况的方法的全部内容,希望文章能够帮你解决pytorch 获取模型参数_Pytorch获取模型参数情况的方法所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(44)

评论列表共有 0 条评论

立即
投稿
返回
顶部