概述
事情的缘由是,坐旁边的学姐有一段代码将两个二维数组相乘。特别的是,既不是点乘,也不是矩阵乘法,而是将各自每一行分别相乘再拼接得到一个三维数组,具体代码大致如下
import torch
a = torch.Tensor(range(6)).reshape(2, 3)
b = torch.Tensor(range(1, 7)).reshape(2, 3)
batch = len(a)
length = len(a[0])
c = torch.zeros(batch, batch, length)
start = time.time()
for i in range(batch):
for j in range(batch):
c[i][j] = a[i] * b[j]
可以看到,实际上是使用a的每一行乘以b的每一行作为c的一个元素,最终c是一个三维数组。
但问题在于,python解释器的执行速度较慢,因此,我做了改进,将a和b分别按照维度1和维度进行扩展,再相乘,结果是一样的。
a_batch = torch.stack([a] * batch, dim=1)
b_batch = torch.stack([b] * batch, dim=0)
c_batch = a_batch * b_batch
速度显然快了很多,但是比较占内存且麻烦。回头看了一个学弟的代码的解决方案,使用了None扩展便捷地解决了这个问题。
a_2 = a[:, None, :]
b_2 = b[None, :, :]
c_2 = a_2 * b_2
实际上,a_2和b_2的维度大小是在None那一维度为1而不是我stack那样的数个。输出各自的维度如下
a.shape torch.Size([2, 3])
a_batch.shape torch.Size([2, 2, 3])
a_2.shape torch.Size([2, 1, 3])
b.shape torch.Size([2, 3])
b_batch.shape torch.Size([2, 2, 3])
b_2.shape torch.Size([1, 2, 3])
疑惑是,之前看的csdn博客,都说过pytorch的矩阵点乘需要两个矩阵的维度相同,然而a_2和b_2为何维度不同也能相乘呢?因此去查询pytorch的官方文档。即查看torch.mul的文档https://pytorch.org/docs/stable/torch.html?highlight=mul#torch.mul,可见
可见,其实在维度不相同时,如果矩阵是可广播的也可以相乘,查看broadcastable的定义 https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
中文翻译即为
如果以下规则成立,则两个张量是“可广播的”:
- 每个张量至少有一个维度。
- 当迭代尺寸时,从尾部尺寸开始,尺寸必须相等,或者其中一个尺寸为1,或者尺寸不存在。
官方举例
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
所以,还是要多查看官方文档,博客上很多是不全面的。
最后
以上就是暴躁河马为你收集整理的一个关于pytorch的tensor点乘的小问题的全部内容,希望文章能够帮你解决一个关于pytorch的tensor点乘的小问题所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复