我是靠谱客的博主 疯狂老鼠,最近开发中收集的这篇文章主要介绍有趣的torch.einsum,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

import torch
import numpy as np
a = torch.arange(9).reshape(3, 3)

提取矩阵对角线元素

out = torch.einsum('ii->i', a)	# tensor([0, 4, 8])

矩阵转置

out = torch.einsum('ij->ji', a)
out = torch.einsum('...ij->...ji', a) # 高维矩阵最后两维转置

reduce sum

out = torch.einsum('ij->', a)	# tensor(36)

矩阵按列求和

out = torch.einsum('ki->i', a)

矩阵向量乘法

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
out = torch.einsum('ik,k->i', a, b)
out = torch.einsum('ik,k', a, b)	# 箭头右侧符号可以不写,按规则默认推理。

矩阵乘法

a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
out = torch.einsum('ik,kj->ij', a, b)
out = torch.einsum('ik,kj', a, b)

向量内积

a = torch.arange(3)
b = torch.arange(3, 6)
out = torch.einsum('i,i->', a, b)
out = torch.einsum('i,i', a, b)

矩阵元素对应相乘并求reduce sum

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
out = torch.einsum('ij,ij->', a, b)

向量外积

a = torch.arange(3)
b = torch.arange(3,7)
out = torch.einsum('i,j->ij', a, b)

batch矩阵乘法

a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
out = torch.einsum('bik,bkj->bij', a, b)

张量收缩

tensor contraction, 用不上,暂时看不懂。

a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
out = torch.einsum('pqrs,tuqvr->pstuv', a, b)

双线性变换

bilinear transformation. Applies a bilinear transformation to the incoming data.

a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
out = torch.einsum('ik,jkl,il->ij', a, b, c)

最后

以上就是疯狂老鼠为你收集整理的有趣的torch.einsum的全部内容,希望文章能够帮你解决有趣的torch.einsum所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(40)

评论列表共有 0 条评论

立即
投稿
返回
顶部