我是靠谱客的博主 整齐哈密瓜,最近开发中收集的这篇文章主要介绍【笔记】torch 乘法总结【笔记】torch 乘法总结,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

【笔记】torch 乘法总结

一、乘号(*) 和 torch.mul()

element-wise 即对应元素相乘

例子:

>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res = a * b
>>> res
tensor([[-0.9672, -0.1052,
0.1392],
[-0.8552,
0.8967, -0.6433]])

特别地,如果是( 2 × 1 × 3 2times 1times 3 2×1×3)和( 2 × 4 × 3 2times 4times 3 2×4×3)这种情况,也可以相乘,结果是( 2 × 4 × 3 2times 4times 3 2×4×3)。相当于用前一个tensor沿着后一个tensor的第1维度expand(broadcast机制)。

例子:

>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res1 = a * b
>>> res1
tensor([[ 0.9199,
0.4053,
0.0789],
[ 2.1330, -0.5653,
0.4760]])
>>> res2 = torch.mul(a,b)
>>> res2
tensor([[ 0.9199,
0.4053,
0.0789],
[ 2.1330, -0.5653,
0.4760]])	# torch.mul() 和 * 效果相同
>>> res3 = a * b.expand(2,3)
>>> res3
tensor([[ 0.9199,
0.4053,
0.0789],	# 两个tensor维度不一致时,
[ 2.1330, -0.5653,
0.4760]])	# 会自动进行expand

二、torch.mm() 与 torch.bmm()

矩阵乘法。

  • torch.mm(mat1, mat2) m a t 1 mat1 mat1 m a t 2 mat2 mat2 进行矩阵乘法,要求输入只能是2维矩阵。

  • torch.bmm(mat1, mat2)专门进行batch形式的矩阵乘法。要求(1)输入只能是3维矩阵( b a t c h , d 1 , d 2 batch, d_1, d_2 batch,d1,d2);(2)第0维度相同。

三、torch.matmul()

矩阵乘法,支持broadcast。

下面是torch.matmul(mat1, mat2)的适用情形:

  • m a t 1 mat1 mat1 m a t 2 mat2 mat2 都是一维向量,则结果返回的是一个标量

  • (标准的二维矩阵乘法)若 m a t 1 ∈ R m × n mat1 in mathbb{R}^{m times n} mat1Rm×n, m a t 2 ∈ R n × d mat2 in mathbb{R}^{n times d} mat2Rn×d,则返回两个矩阵乘积 o u t p u t ∈ R m × d outputin mathbb{R}^{mtimes d} outputRm×d

  • m a t 1 ∈ R n mat1 in mathbb{R}^{n} mat1Rn 是一维向量,则会为 m a t 1 mat1 mat1 添加一个维度,变成 m a t 1 ∈ R 1 × n mat1 in mathbb{R}^{1 times n} mat1R1×n,然后与矩阵2 m a t 2 ∈ R n × d mat2 in mathbb{R}^{n times d} mat2Rn×d 进行二维矩阵乘法。并在最后的结果中移除添加的那一维度 => R d mathbb{R}^{d} Rd

  • 若至少有一个参与运算的矩阵的维度大于2,则进行batch矩阵乘法

    这里会把矩阵的后两个维度视作矩阵维度(matrix dimensions),参与矩阵运算,其他维度视作batch维度,进行broadcast处理。

    两个例子:

    (1) mat1 是一个 size 为 ( j × 1 × n × n j times 1 times n times n j×1×n×n) 的 tensor ,mat2 是一个 size 为 ( k × n × n k times n times n k×n×n) 的 tensor, out 将会是一个 ( j × k × n × n ) j times k times n times n) j×k×n×n) 的 tensor. 这里 ( n × n n times n n×n) 部分是矩阵维度,( k k k) 和 ( j × k j times k j×k) 是 batch 维度。

    (2) mat1 是一个 size 为 ( j × 1 × n × m j times 1 times n times m j×1×n×m) 的 tensor , mat2 是一个 size 为 ( k × m × p k times m times p k×m×p) 的 tensor, 这些输入的 tensors 支持 broadcasting 机制,即使最后两个矩阵维度是不同的。 输出out 将会是一个 size 为 ( j × k × n × p j times k times n times p j×k×n×p) 的 tensor.

来源:torch.matmul — PyTorch 1.8.1 documentation

最后

以上就是整齐哈密瓜为你收集整理的【笔记】torch 乘法总结【笔记】torch 乘法总结的全部内容,希望文章能够帮你解决【笔记】torch 乘法总结【笔记】torch 乘法总结所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(41)

评论列表共有 0 条评论

立即
投稿
返回
顶部