概述
einsum用于矩阵乘法
直接上例子吧
比如
'bhqd, bhkd -> bhqk'
虽然是4维,但是前两维是不变的,先不看,只看后2维,qd, kd -> qk
这是两个矩阵相乘,两个矩阵的shape分别为A=qxd, B=kxd, 得到的结果形状是C =qxk
根据矩阵乘法,我们知道(qxd) x (dxk)结果的形状为qxk,
也就是说上面相当于是AxBT=C
验证一下
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy.shape',energy.shape)
queries.shape torch.Size([1, 8, 197, 96])
key.shape torch.Size([1, 8, 197, 96])
energy.shape torch.Size([1, 8, 197, 197])
可以看到相当于queries x keysT, 即形状(197x96) x (197x96)T=(197x197)
再看一个
'bhal, bhlv -> bhav'
前两维一样的,不看,只看后两维,仍然看作是矩阵的形状A=axl, B=lxv
矩阵相乘(axl) x (lxv) = (axv),和结果的av相同
所以上面相当于是A与B相乘
验证一下
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out.shape',out.shape)
att.shape torch.Size([1, 8, 197, 197])
values.shape torch.Size([1, 8, 197, 96])
out.shape torch.Size([1, 8, 197, 96])
可以看到相当于att x values,即形状(197x197) x (197x96) = (197x96)
最后
以上就是小巧中心为你收集整理的pytorch einsum 矩阵乘 浅显易懂解释的全部内容,希望文章能够帮你解决pytorch einsum 矩阵乘 浅显易懂解释所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复