我是靠谱客的博主 暴躁河马,最近开发中收集的这篇文章主要介绍一个关于pytorch的tensor点乘的小问题,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

事情的缘由是,坐旁边的学姐有一段代码将两个二维数组相乘。特别的是,既不是点乘,也不是矩阵乘法,而是将各自每一行分别相乘再拼接得到一个三维数组,具体代码大致如下

import torch
a = torch.Tensor(range(6)).reshape(2, 3)
b = torch.Tensor(range(1, 7)).reshape(2, 3)
batch = len(a)
length = len(a[0])
c = torch.zeros(batch, batch, length)
start = time.time()
for i in range(batch):
    for j in range(batch):
        c[i][j] = a[i] * b[j]

可以看到,实际上是使用a的每一行乘以b的每一行作为c的一个元素,最终c是一个三维数组。

但问题在于,python解释器的执行速度较慢,因此,我做了改进,将a和b分别按照维度1和维度进行扩展,再相乘,结果是一样的。

a_batch = torch.stack([a] * batch, dim=1)
b_batch = torch.stack([b] * batch, dim=0)
c_batch = a_batch * b_batch

速度显然快了很多,但是比较占内存且麻烦。回头看了一个学弟的代码的解决方案,使用了None扩展便捷地解决了这个问题。

a_2 = a[:, None, :]
b_2 = b[None, :, :]
c_2 = a_2 * b_2

实际上,a_2和b_2的维度大小是在None那一维度为1而不是我stack那样的数个。输出各自的维度如下

a.shape  torch.Size([2, 3])
a_batch.shape  torch.Size([2, 2, 3])
a_2.shape  torch.Size([2, 1, 3])
b.shape  torch.Size([2, 3])
b_batch.shape  torch.Size([2, 2, 3])
b_2.shape  torch.Size([1, 2, 3])

疑惑是,之前看的csdn博客,都说过pytorch的矩阵点乘需要两个矩阵的维度相同,然而a_2和b_2为何维度不同也能相乘呢?因此去查询pytorch的官方文档。即查看torch.mul的文档https://pytorch.org/docs/stable/torch.html?highlight=mul#torch.mul,可见

可见,其实在维度不相同时,如果矩阵是可广播的也可以相乘,查看broadcastable的定义 https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

中文翻译即为

如果以下规则成立,则两个张量是“可广播的”:

  1. 每个张量至少有一个维度
  2. 当迭代尺寸时,从尾部尺寸开始,尺寸必须相等,或者其中一个尺寸为1,或者尺寸不存在

官方举例

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

所以,还是要多查看官方文档,博客上很多是不全面的。

最后

以上就是暴躁河马为你收集整理的一个关于pytorch的tensor点乘的小问题的全部内容,希望文章能够帮你解决一个关于pytorch的tensor点乘的小问题所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(47)

评论列表共有 0 条评论

立即
投稿
返回
顶部