概述
方式一: 直接取布尔值
输入:
target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)
masked_target = target[mask]
print(target)
print(mask)
print(masked_target)
输入:
target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.ge(0)
masked_target = torch.masked_select(target, mask)
print(target)
print(mask)
print(masked_target)
输出:
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False, True, False, False, True])
tensor([1., 2., 3.])
方式二:自己设置ByteTensor作为掩码
输入:
target = torch.Tensor([1,0,0,2,0,0,3])
mask = torch.ByteTensor([1,0,0,1,0,0,0])
masked_target = torch.masked_select(target, mask)
print(target)
print(masked_target)
输出:
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([1., 2.])
最后
以上就是傲娇眼睛为你收集整理的pytorch中几种tensor掩码的获取方法(含代码)的全部内容,希望文章能够帮你解决pytorch中几种tensor掩码的获取方法(含代码)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复