概述
unable to get repr for class 'torch.tensor'
出错代码:
batch_conf.gather(1, conf_t.view(-1,1))
最近码代码使用pytorch遇到如题所示的问题,查遍Google百度,大多是说运算时维度不符,但是我找遍代码也没发现有这个错误。一段时间后才发现,网络参数保存的是torch.float32类型,而我输入的数据是torch.float64类型,将数据类型更改为torch.float32,问题解决。
我是因为是用别人的训练代码,没有改完,除了bug,导致最后输出的神经元个数(类别数)小于给的label-1(从0开始)的值。必须是神经元个数即类别数要完全等于maximum label value-1,比如分成10类,label最大只能是9,超过9的情况出现就会出现题目中的错误,然后pytorch还没有提示。。。
网上还有别的情况:
https://blog.csdn.net/jizhidexiaoming/article/details/109442337
问题描述:计算BCE Loss
使用pytorch接口
self.bce_loss = nn.BCELoss()
self.bce_loss(pred_cls, tcls)
问题原因:pred_cls没有归一化的0到1之间。
解决办法:
self.bce_loss(torch.sigmoid(pred_cls), tcls)
最后
以上就是狂野小虾米为你收集整理的unable to get repr for class ‘torch.tensor‘unable to get repr for class 'torch.tensor'的全部内容,希望文章能够帮你解决unable to get repr for class ‘torch.tensor‘unable to get repr for class 'torch.tensor'所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复