我是靠谱客的博主 朴素灯泡,最近开发中收集的这篇文章主要介绍pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
Traceback (most recent call last):
File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 276, in <module>
mask = grad_cam(input, target_index)
File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 104, in __call__
features, output = self.extractor(input.cuda())
File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 62, in __call__
target_activations, output = self.feature_extractor(x)
File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 43, in __call__
x = module(x)
File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/functional.py", line 1692, in linear
output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
分析:
报错代码:
x = self.fc(x)
全连接层结构:
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) # block.expansion=1, fc_scale=7*7,num_features=512
输入和全连接层权重相乘,输入的dim=1维要和全连接层权重转置后的dim=0维相等。
原始输入的shape: torch.Size([1, 512, 14, 14]),dim1=512
全连接层weight转置前的shape: torch.Size([512, 25088]),.t()转置后为(25088,512),dim0=25088
解决:对输入x做shape变换,将其dim=1维调整为25088,使之满足维度相等关系
x = x.view(-1,25088) # 25088 = 512 * 7 * 7
https://stackoverflow.com/questions/64868040/runtimeerror-mat1-dim-1-must-match-mat2-dim-0-whenever-i-run-modelimages
最后
以上就是朴素灯泡为你收集整理的pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0的全部内容,希望文章能够帮你解决pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复