概述
import torch
import numpy as np
a = torch.arange(9).reshape(3, 3)
提取矩阵对角线元素
out = torch.einsum('ii->i', a) # tensor([0, 4, 8])
矩阵转置
out = torch.einsum('ij->ji', a)
out = torch.einsum('...ij->...ji', a) # 高维矩阵最后两维转置
reduce sum
out = torch.einsum('ij->', a) # tensor(36)
矩阵按列求和
out = torch.einsum('ki->i', a)
矩阵向量乘法
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
out = torch.einsum('ik,k->i', a, b)
out = torch.einsum('ik,k', a, b) # 箭头右侧符号可以不写,按规则默认推理。
矩阵乘法
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
out = torch.einsum('ik,kj->ij', a, b)
out = torch.einsum('ik,kj', a, b)
向量内积
a = torch.arange(3)
b = torch.arange(3, 6)
out = torch.einsum('i,i->', a, b)
out = torch.einsum('i,i', a, b)
矩阵元素对应相乘并求reduce sum
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
out = torch.einsum('ij,ij->', a, b)
向量外积
a = torch.arange(3)
b = torch.arange(3,7)
out = torch.einsum('i,j->ij', a, b)
batch矩阵乘法
a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
out = torch.einsum('bik,bkj->bij', a, b)
张量收缩
tensor contraction, 用不上,暂时看不懂。
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
out = torch.einsum('pqrs,tuqvr->pstuv', a, b)
双线性变换
bilinear transformation. Applies a bilinear transformation to the incoming data.
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
out = torch.einsum('ik,jkl,il->ij', a, b, c)
最后
以上就是疯狂老鼠为你收集整理的有趣的torch.einsum的全部内容,希望文章能够帮你解决有趣的torch.einsum所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复