我是靠谱客的博主 热心小伙,最近开发中收集的这篇文章主要介绍最常用的3种Pytorch tensor的维度变化方法,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

最近一周一直在搭建新的模型,学到了很多新的tensor维度变化操作,记录一下:

a = torch.tensor([1,2,3])

1、None增加一维:

b = a[None, :]
# 改成两维(一个逗号),增加到第一维
>>> b == tensor([[1, 2, 3]])
c = a[:, None]
# 改成两维(一个逗号),增加到第二维
>>> c == tensor([[1],[2],[3]])
d = a[:,None, None]
# 改成三维(两个逗号),增加到第二、三维
>>> d == tensor([[[1]],[[2]],[[3]]])

可以看出,None就是用来加[]的,一个None比原来多一维,n个逗号代表最后的向量是n+1维度的。
值得注意的是,None这种加[]不会改变原向量a,并且返回的新向量的属性也是默认tensor属性而不是a的,比如设置a的requires_grad=True,b c d还是默认的requires_grad=False。

2、view(*args) 改变形状

这是一个很有意思的方法,做的事情和None有点类似,但是它可以压缩向量维度。就是把一个tensor的元素先按行列依次展平,然后再reshape成view中传入参数的形状。
其中-1代表自适应。
例如:
1) tensor.view(-1) 代表将tensor展成自适应一维。(注意,如果值view一维,那么传参要不是-1,要不就是所有元素个数。因此view一维建议全部用view(-1)表示)。
2) tensor.view(1,-1) 代表将tensor变成2维,其中第一维是自适应。
3) tensor.view(1, 1,-1) 代表将tensor变成3维,其中第一维是自适应。
…依次类推,如果设置的张量shape没法整除,那么程序会抛出RuntimeError错误,可以后期捕捉。

3、squeeze()和unsqueeze(*args)

tensor.squeeze()不用传参,代表将张量中所有shape1的维数去掉,用于节省内存或者满足强迫症需求。注意的是suqeeze方法去的是所有的shape1的维数。

  1. tensor.unsqueeze(0) 代表在第0维加上一个维数为1的
  2. tensor.unsqueeze(1) 代表在第1维加上一个维数为1的

最后

以上就是热心小伙为你收集整理的最常用的3种Pytorch tensor的维度变化方法的全部内容,希望文章能够帮你解决最常用的3种Pytorch tensor的维度变化方法所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(41)

评论列表共有 0 条评论

立即
投稿
返回
顶部