我是靠谱客的博主 朴素灯泡,最近开发中收集的这篇文章主要介绍pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Traceback (most recent call last):
  File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 276, in <module>
    mask = grad_cam(input, target_index)
  File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 104, in __call__
    features, output = self.extractor(input.cuda())
  File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 62, in __call__
    target_activations, output  = self.feature_extractor(x)
  File "/home/user1/pjs/frvt_pytorch/batch_run/2branch_alter_1update_2pfc_MMD_ori_auto/recognition/arcface_torch/tools/visualize.py", line 43, in __call__
    x = module(x)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

分析:
报错代码:

x = self.fc(x)

全连接层结构:

self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) # block.expansion=1, fc_scale=7*7,num_features=512

输入和全连接层权重相乘,输入的dim=1维要和全连接层权重转置后的dim=0维相等。
原始输入的shape: torch.Size([1, 512, 14, 14]),dim1=512
全连接层weight转置前的shape: torch.Size([512, 25088]),.t()转置后为(25088,512),dim0=25088

解决:对输入x做shape变换,将其dim=1维调整为25088,使之满足维度相等关系

x = x.view(-1,25088) # 25088 = 512 * 7 * 7

https://stackoverflow.com/questions/64868040/runtimeerror-mat1-dim-1-must-match-mat2-dim-0-whenever-i-run-modelimages

最后

以上就是朴素灯泡为你收集整理的pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0的全部内容,希望文章能够帮你解决pytorch resnet 全连接层linear报错:RuntimeError: mat1 dim 1 must match mat2 dim 0所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(69)

评论列表共有 0 条评论

立即
投稿
返回
顶部