概述
【笔记】torch 乘法总结
一、乘号(*) 和 torch.mul()
element-wise 即对应元素相乘
例子:
>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res = a * b
>>> res
tensor([[-0.9672, -0.1052,
0.1392],
[-0.8552,
0.8967, -0.6433]])
特别地,如果是( 2 × 1 × 3 2times 1times 3 2×1×3)和( 2 × 4 × 3 2times 4times 3 2×4×3)这种情况,也可以相乘,结果是( 2 × 4 × 3 2times 4times 3 2×4×3)。相当于用前一个tensor沿着后一个tensor的第1维度expand(broadcast机制)。
例子:
>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res1 = a * b
>>> res1
tensor([[ 0.9199,
0.4053,
0.0789],
[ 2.1330, -0.5653,
0.4760]])
>>> res2 = torch.mul(a,b)
>>> res2
tensor([[ 0.9199,
0.4053,
0.0789],
[ 2.1330, -0.5653,
0.4760]]) # torch.mul() 和 * 效果相同
>>> res3 = a * b.expand(2,3)
>>> res3
tensor([[ 0.9199,
0.4053,
0.0789], # 两个tensor维度不一致时,
[ 2.1330, -0.5653,
0.4760]]) # 会自动进行expand
二、torch.mm() 与 torch.bmm()
矩阵乘法。
-
torch.mm(mat1, mat2)
对 m a t 1 mat1 mat1 和 m a t 2 mat2 mat2 进行矩阵乘法,要求输入只能是2维矩阵。 -
torch.bmm(mat1, mat2)
专门进行batch形式的矩阵乘法。要求(1)输入只能是3维矩阵( b a t c h , d 1 , d 2 batch, d_1, d_2 batch,d1,d2);(2)第0维度相同。
三、torch.matmul()
矩阵乘法,支持broadcast。
下面是torch.matmul(mat1, mat2)
的适用情形:
-
若 m a t 1 mat1 mat1 , m a t 2 mat2 mat2 都是一维向量,则结果返回的是一个标量。
-
(标准的二维矩阵乘法)若 m a t 1 ∈ R m × n mat1 in mathbb{R}^{m times n} mat1∈Rm×n, m a t 2 ∈ R n × d mat2 in mathbb{R}^{n times d} mat2∈Rn×d,则返回两个矩阵乘积 o u t p u t ∈ R m × d outputin mathbb{R}^{mtimes d} output∈Rm×d。
-
若 m a t 1 ∈ R n mat1 in mathbb{R}^{n} mat1∈Rn 是一维向量,则会为 m a t 1 mat1 mat1 添加一个维度,变成 m a t 1 ∈ R 1 × n mat1 in mathbb{R}^{1 times n} mat1∈R1×n,然后与矩阵2 m a t 2 ∈ R n × d mat2 in mathbb{R}^{n times d} mat2∈Rn×d 进行二维矩阵乘法。并在最后的结果中移除添加的那一维度 => R d mathbb{R}^{d} Rd 。
-
若至少有一个参与运算的矩阵的维度大于2,则进行batch矩阵乘法。
这里会把矩阵的后两个维度视作矩阵维度(matrix dimensions),参与矩阵运算,其他维度视作batch维度,进行broadcast处理。
两个例子:
(1)
mat1
是一个 size 为 ( j × 1 × n × n j times 1 times n times n j×1×n×n) 的 tensor ,mat2
是一个 size 为 ( k × n × n k times n times n k×n×n) 的 tensor,out
将会是一个 ( j × k × n × n ) j times k times n times n) j×k×n×n) 的 tensor. 这里 ( n × n n times n n×n) 部分是矩阵维度,( k k k) 和 ( j × k j times k j×k) 是 batch 维度。(2)
mat1
是一个 size 为 ( j × 1 × n × m j times 1 times n times m j×1×n×m) 的 tensor ,mat2
是一个 size 为 ( k × m × p k times m times p k×m×p) 的 tensor, 这些输入的 tensors 支持 broadcasting 机制,即使最后两个矩阵维度是不同的。 输出out
将会是一个 size 为 ( j × k × n × p j times k times n times p j×k×n×p) 的 tensor.
来源:torch.matmul — PyTorch 1.8.1 documentation
最后
以上就是整齐哈密瓜为你收集整理的【笔记】torch 乘法总结【笔记】torch 乘法总结的全部内容,希望文章能够帮你解决【笔记】torch 乘法总结【笔记】torch 乘法总结所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复