我是靠谱客的博主 呆萌哈密瓜,这篇文章主要介绍PyTorch中的矩阵乘法torch.mm()torch.bmm()torch.matmul()torch.mul()总结,现在分享给大家,希望可以做个参考。

torch.mm()

torch.mm(input, mat2, out=None) → Tensor

矩阵乘法,不进行 broadcast

torch.bmm()

输入1 :(b×n×m) tensor, 输入2:(b×m×p) tensor, 输出:(b×n×p) tensor.
batch 式的矩阵乘法,不broadcast

torch.matmul()

torch.matmul(input, other, out=None) → Tensor

矩阵乘法,有broadcast功能

  1. 如果输入的tensor都是一维,则计算点积:
a = torch.Tensor([1,2,3])
b = torch.Tensor([1,1,1])
torch.matmul(a,b)
# tensor(6.)
  1. 如果输入tensor都是二维矩阵,则计算矩阵乘法:
a = torch.Tensor(3,2)
b = torch.Tensor(2,3)
torch.matmul(a,b).shape
# torch.Size([3, 3])
  1. 如果是[b x m x k]与[b x k x n]形式的矩阵乘法,则进行batched matrix multiply得到[b x m x n]矩阵:
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3, 5])

torch.mul()

  1. 如果输入tensor形状相同,则元素相乘:
a = torch.Tensor(2,3)
b = torch.Tensor(2,3)
torch.mul(a,b).shape
# torch.Size([2, 3])
  1. 其他:
a = torch.Tensor([1,2,3])
b = torch.Tensor([[1],[2],[3]])
torch.mul(a,b)
# tensor([[1., 2., 3.],
#
[2., 4., 6.],
#
[3., 6., 9.]])

总结

torch.mm()和torch.bmm()分别是单纯矩阵乘法和batch矩阵乘法,不进行broadcast,比较简单明了。torch.matmul()也是矩阵乘法,但是有broadcast,比较灵活,可以单纯矩阵乘法也可以batch矩阵乘法。
torch.mul()则是元素相乘。

最后

以上就是呆萌哈密瓜最近收集整理的关于PyTorch中的矩阵乘法torch.mm()torch.bmm()torch.matmul()torch.mul()总结的全部内容,更多相关PyTorch中内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(120)

评论列表共有 0 条评论

立即
投稿
返回
顶部