我是靠谱客的博主 优雅帽子,最近开发中收集的这篇文章主要介绍深入解析Tensor索引中的Indexing Multi-dimensional arrays问题写在前面前置知识问题描述理性分析写在最后,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

写在前面

最近小弟做了一些实验,但是发现我写的代码虽然能够跑通,但是对于gpu的利用率始终在一个比较低的水平,这就很难受,别人的代码2h就跑完了,我得10h,经过排查发现究其原因就是代码的并行化成都不高,在代码中使用了大量的for循环,没有采用矩阵运算,就导致计算非常的慢,于是最近在学习一些大神的代码,遇到了这个在Tensor中的Indexing Multi-dimensional arrays问题。

前置知识

在解决这个问题之前,需要了解torch中的boardcast机制,详情可见pytorch官网。
简单来说就是两个tensor在判断是否可以广播时,要进行以下两个步骤:

x=torch.ones(5,1,4,2)
y=torch.ones(3,1,1)
(x+y).size() # Tensor(5, 3, 4, 2)
  1. 将两个tensor按照尾部的维度对齐,即x的最后一个维度2与y的最后一个维度1对齐,x的倒数第二个维度1与y的倒数第二个维度1对齐,直到其中的一个tensor没有维度就停止。在这里就是y的第一个维度3停止,对应于x的第二个维度1.
  2. 在每次对齐中,如果两个维度不一样,且其中一个维度为1,那么就把1变成另维度,比如最后一个维度是1和2,那么就把1变成2(代表着y在最后一个维度(列)上copy了一次);比如3和1,那么就把1变成3(代表着x在第二个维度上copy了两次)
  • PS:如果一对对应的维度数字中,两个数字不同,并且没有一个维度为1,那么就会报错

问题描述

话不多说,直接上问题。

x=torch.ones(5,8,512) # Tensor(5, 8, 512)
y=torch.ones(5,6,2) # Tensor(5, 6, 2)
result = x[torch.arange(x.size(0)), y.permute(2, 1, 0)] # Tensor(2, 6, 5, 512)

通过上述代码可以看到,x的维度为Tensor(5, 8, 512),y的维度为Tensor(5, 6, 2),但是result出来的维度为Tensor(2, 6, 5, 512),这我直接顶不住,不知道为神马。且听下面分析。

理性分析

a = torch.arange(x.size(0)) # Tensor(5, )
aa = y.permute(2, 1, 0) # Tensor(2, 6, 5)
aaa = [torch.arange(x.size(0)), y.permute(2, 1, 0)] # [Tensor(5, ), Tensor(2, 6, 5)]

可以看到,aaa就对应了result中我们要取的Multi-dimensional Index,它是一个list,中间有两个元素,每个元素都是一个Tensor,且维度不一样。这里就有一个补充知识,在做这种Multi-dimensional Index操作的时候,list中的元素要么需要保证维度相同,要么需要保证可以广播,因为在维度不相同时便会进行广播操作,详见NumPy文档。

所以此时就要对a和aa进行广播,得到Tensor(2, 6, 5)这个维度,于是索引就变为了[Tensor(2, 6, 5), Tensor(2, 6, 5)]。

那么在索引是list中嵌套list时,是如何根据下标索引元素的,看下面这个例子。

 y = np.arange(35).reshape(5,7)
y[np.array([0,2,4]), np.array([0,1,2])]
>>> array([ 0, 15, 30])

可以看到,索引的过程是第一个list中的[0]元素与第二个list中的[0]元素对应着取的,并不是n*n,而是n个一一对应的结果,如果是Tensor的话则会自动做一个concat操作。

那么刚刚的问题其实就变成了:

x[[Tensor(2,6,5)], Tensor[2,6,5]] # 其实就相当于选取两个下标,对应于i和j,选2*6*5次

并且由于list中只有2个元素,第0号元素对应x中第0个维度(即5),第1号元素对应x中第1个维度(即8),得到的维度即为Tensor(2, 6, 5, 512)。

如果换成下面这种写法:

result = x[torch.arange(x.size(0)), y.permute(2, 1, 0), y.permute(2, 1, 0)]

则list中有三个元素,广播过后每个元素的维度为Tensor(2,6,5),且分别对应x的三个维度,出来的维度即为Tensor(2,6,5)了。

写在最后

Tensor这种操作,当维度高了过后确实无法用空间去想象了,比较抽象,希望后面能慢慢熟悉,提高代码可读性。

最后

以上就是优雅帽子为你收集整理的深入解析Tensor索引中的Indexing Multi-dimensional arrays问题写在前面前置知识问题描述理性分析写在最后的全部内容,希望文章能够帮你解决深入解析Tensor索引中的Indexing Multi-dimensional arrays问题写在前面前置知识问题描述理性分析写在最后所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(83)

评论列表共有 0 条评论

立即
投稿
返回
顶部