我是靠谱客的博主 傲娇眼睛,最近开发中收集的这篇文章主要介绍pytorch中几种tensor掩码的获取方法(含代码),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

方式一: 直接取布尔值

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)
masked_target = target[mask]

print(target)
print(mask)
print(masked_target)

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.ge(0)
masked_target = torch.masked_select(target, mask)

print(target)
print(mask)
print(masked_target)

输出:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False,  True, False, False,  True])
tensor([1., 2., 3.])

 

方式二:自己设置ByteTensor作为掩码

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = torch.ByteTensor([1,0,0,1,0,0,0])
masked_target = torch.masked_select(target, mask)

print(target)
print(masked_target)

输出:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([1., 2.])

最后

以上就是傲娇眼睛为你收集整理的pytorch中几种tensor掩码的获取方法(含代码)的全部内容,希望文章能够帮你解决pytorch中几种tensor掩码的获取方法(含代码)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(64)

评论列表共有 0 条评论

立即
投稿
返回
顶部