pytorch einsum 矩阵乘 浅显易懂解释
einsum用于矩阵乘法直接上例子吧比如'bhqd, bhkd -> bhqk'虽然是4维,但是前两维是不变的,先不看,只看后2维,qd, kd -> qk这是两个矩阵相乘,两个矩阵的shape分别为A=qxd, B=kxd, 得到的结果形状是C =qxk根据矩阵乘法,我们知道(qxd) x (dxk)结果的形状为qxk,也就是说上面相当于是AxBT=C验证一下energy = torch.einsum('bhqd, bhkd -> bhqk', que