我是靠谱客的博主 小巧中心,最近开发中收集的这篇文章主要介绍pytorch einsum 矩阵乘 浅显易懂解释,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

einsum用于矩阵乘法
直接上例子吧
比如

'bhqd, bhkd -> bhqk'

虽然是4维,但是前两维是不变的,先不看,只看后2维,qd, kd -> qk
这是两个矩阵相乘,两个矩阵的shape分别为A=qxd, B=kxd, 得到的结果形状是C =qxk
根据矩阵乘法,我们知道(qxd) x (dxk)结果的形状为qxk,
也就是说上面相当于是AxBT=C

验证一下

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy.shape',energy.shape)
queries.shape torch.Size([1, 8, 197, 96])
key.shape torch.Size([1, 8, 197, 96])
energy.shape torch.Size([1, 8, 197, 197])

可以看到相当于queries x keysT, 即形状(197x96) x (197x96)T=(197x197)

再看一个

'bhal, bhlv -> bhav'

前两维一样的,不看,只看后两维,仍然看作是矩阵的形状A=axl, B=lxv
矩阵相乘(axl) x (lxv) = (axv),和结果的av相同
所以上面相当于是A与B相乘

验证一下

out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out.shape',out.shape)
att.shape torch.Size([1, 8, 197, 197])
values.shape torch.Size([1, 8, 197, 96])
out.shape torch.Size([1, 8, 197, 96])

可以看到相当于att x values,即形状(197x197) x (197x96) = (197x96)

最后

以上就是小巧中心为你收集整理的pytorch einsum 矩阵乘 浅显易懂解释的全部内容,希望文章能够帮你解决pytorch einsum 矩阵乘 浅显易懂解释所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部