危机老鼠

文章
4
资源
0
加入时间
3年0月27天

PyTorch 矩阵乘法总结和运算符重载

1. 二维矩阵乘法 torch.mm()也就是最基本的矩阵乘法,需要满足对应维度的要求,否则报错torch.mm(mat1, mat2, out=None)mat1∈Rm×n\in \mathbb{R}^{m \times n}∈Rm×n,mat2 ∈Rn×d\in \mathbb{R}^{n \times d}∈Rn×d,输出 out ∈Rm×d\in \mathbb{R}^{m \times d}∈Rm×d。2. 三维带batch的矩阵乘法 torch.bmm()torch.bmm(bmat