我是靠谱客的博主 冷傲咖啡,最近开发中收集的这篇文章主要介绍pytorch einsum, numpy einsum什么是einsum?为什么用?具体例子,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

什么是einsum?

爱因斯坦求和约定:

https://zhuanlan.zhihu.com/p/101157166
https://en.wikipedia.org/wiki/Einstein_notation

为什么用?

简洁, 强大

具体例子

pytorch和numpy一样的,这里以pytorch为例.

矩阵乘矩阵, C = A x B

eg1. (ij, jk -> ik), 实际上就是

[ C ] i k = ∑ j [ A ] i j × [ B ] j k [C]_{ik} = sum_{j} [A]_{ij} times [B]_{jk} [C]ik=j[A]ij×[B]jk

>>> A = torch.Tensor([[1,1],[2,2]])
>>> B = torch.Tensor([[1,-1],[-1,1]])
>>> torch.einsum("ij, jk -> ik", A, B)
tensor([[0., 0.],
        [0., 0.]])

eg2. (mij, jk -> mik), 相当于在eg1中的m个A分别乘以B,得到m个相应的C

>>> mA = torch.stack([a, a+1, a+2], dim=0)
>>> mA
tensor([[[1., 1.],
         [2., 2.]],

        [[2., 2.],
         [3., 3.]],

        [[3., 3.],
         [4., 4.]]])
>>> mA.shape
torch.Size([3, 2, 2])
>>> B = torch.Tensor([[1,1],[1,1]])
>>> torch.einsum("mij, jk -> mik", mA, B)
tensor([[[2., 2.],
         [4., 4.]],

        [[4., 4.],
         [6., 6.]],

        [[6., 6.],
         [8., 8.]]])

矩阵乘向量

eg3. (mij, j -> mi).

>>> b = torch.Tensor([1,-1])
>>> torch.einsum("mij, j -> mi", mA, b)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

eg4. (bnf, n -> bf).

相当于对n个f维向量加权求和.

>>> 
a = torch.arange(12).reshape(3,2,2).float()
print(a)
weighted = 2 * torch.ones(2)
print(weighted)
b = torch.einsum('bnf, n -> bf', a, weighted)
print(b)
print(b.shape)
>>>
output:
tensor([[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.]]])
tensor([2., 2.])
tensor([[ 4.,  8.],
        [20., 24.],
        [36., 40.]])
torch.Size([3, 2])

高维Tensor乘法

eg5. (bnft, knm -> bkmft).

看起来有点复杂,其实很简单:
bnft相当于b个mA, knm 相当于k个B,那么bkmft,相当于b个mA分别乘以k个B。

但是你会发现,mA是三维的(nft),B是二维的(nm),这其实等同于eg2。

观察最后的结果(bkmft)里,bk就是b个mA乘以k个B,所以一共有b*k个tensor运算结果,存到了b维和k维。

剩下的mft = nft *nm,实际上要让这个等式成立,必须要变成mft = mn * nft, 即nm转置为mn,然后 分别乘以t个nf,即,t个矩阵乘法: m n × n f = m f mn times nf = mf mn×nf=mf

>>> mA.shape
torch.Size([3, 2, 2])
>>> bmA = torch.stack([mA,-mA], dim=0)
>>> bmA.shape
torch.Size([2, 3, 2, 2])
>>> B = torch.Tensor([[1,1],[1,-1]])
>>> kB = torch.stack([B,B,B],dim=1)
>>> kB.shape
torch.Size([2, 3, 2])
>>> torch.einsum("bnft, lnm -> blmft", bmA, lB).shape
torch.Size([2, 2, 2, 2, 2])
>>> torch.mm(kB[1].transpose(0,1), bmA[0,:,:,0])
tensor([[ 6.,  9.],
        [-6., -9.]])
>>> torch.einsum("bnft, knm -> bkmft", bmA, kB)[0,1,:,:,0]
tensor([[ 6.,  9.],
        [-6., -9.]])

从最后两个code可以看到,用kB中的第二个(index=1)B,进行转置后,乘以第一个mA中的第一个nf,即(t这个维度取第一个,index=0)的结果,等同于einsum结果中对应位置的矩阵(即b=0,k=1,t=0)。

eg6. (bkmft, k -> bmft). 高维Tensor乘以向量。

相当于对左边的 b ∗ k b*k bk m f t mft mft 进行了一个加权求和(权重会乘以每一个元素),权重就是右边的k维向量。

>>> C = torch.einsum("bnft, knm -> bkmft", bmA, kB)
>>> b = torch.Tensor([1, 2]) # 权重第一个mft就是1,第二个mft就是2.
>>> torch.einsum("bkmft, k -> bmft", C, b)
tensor([[[[ 18.,  18.],
          [ 27.,  27.]],

         [[ -6.,  -6.],
          [ -9.,  -9.]]],


        [[[-18., -18.],
          [-27., -27.]],

         [[  6.,   6.],
          [  9.,   9.]]]])

最后

以上就是冷傲咖啡为你收集整理的pytorch einsum, numpy einsum什么是einsum?为什么用?具体例子的全部内容,希望文章能够帮你解决pytorch einsum, numpy einsum什么是einsum?为什么用?具体例子所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(44)

评论列表共有 0 条评论

立即
投稿
返回
顶部