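I'm 任性吐司, a blogger on 靠谱客. I came across this article during recent development work. It compares the performance of PyTorch reshape and view, and of einsum and matmul. I found it useful, so I'm sharing it here for reference.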
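Overview
reshape and view
I won't go into the exact semantics of reshape and view here; the official docs cover them. In short, view never copies. It returns a new tensor that shares the original storage and changes only the shape and stride metadata. If the requested shape is incompatible with the input's strides, view raises an error. reshape returns a view whenever it can. Otherwise it copies the data into new storage, for example on a non-contiguous tensor produced by a transpose.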
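To make the difference concrete, here is a small sketch on a toy 2×3 tensor that runs on CPU. It checks storage sharing with data_ptr(). On a contiguous tensor, both calls reuse the existing storage. After a transpose, view raises a RuntimeError while reshape quietly copies.

```python
import torch

# Contiguous 2x3 tensor: [[0, 1, 2], [3, 4, 5]]
x = torch.arange(6).reshape(2, 3)

# Both calls reinterpret the same storage; no data is copied.
v = x.view(3, 2)
r = x.reshape(3, 2)
print(v.data_ptr() == x.data_ptr(), r.data_ptr() == x.data_ptr())  # True True

# Transposing swaps the strides, so the tensor is no longer contiguous.
xt = x.t()
print(xt.is_contiguous())  # False

# view cannot flatten xt using its existing strides, so it raises.
try:
    xt.view(6)
except RuntimeError as e:
    print("view failed:", type(e).__name__)

# reshape falls back to copying into fresh, contiguous storage.
flat = xt.reshape(6)
print(flat.data_ptr() == x.data_ptr())  # False
print(flat.tolist())  # [0, 3, 1, 4, 2, 5]
```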
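So which one should you use?
Short answer: performance is about the same. For a contiguous tensor, like the inputs in the benchmark below, reshape simply returns a view. Both calls therefore only build new shape and stride metadata.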
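Here is a quick way to check this, as a sketch with shapes shrunk from the benchmark so it runs on CPU. It compares data_ptr() to confirm that reshape returned a view, then times both calls with the standard-library timeit. The absolute numbers will vary by machine. The point is that neither call touches the tensor data.

```python
import timeit

import torch

# Assumption: a contiguous input, as produced by torch.rand in the benchmark,
# but smaller so the sketch runs quickly on CPU.
x = torch.rand(128, 30, 30, 10)
shape_to = (64, 60, 30, 10)  # same element count: 1,152,000

v = x.view(shape_to)
r = x.reshape(shape_to)

# For a contiguous input, reshape hands back a view of the same storage.
same_storage = v.data_ptr() == r.data_ptr() == x.data_ptr()
print("reshape returned a view:", same_storage)

# Both calls only create new metadata, so each is cheap and roughly equal.
t_view = timeit.timeit(lambda: x.view(shape_to), number=10_000)
t_reshape = timeit.timeit(lambda: x.reshape(shape_to), number=10_000)
print(f"view: {t_view:.4f}s  reshape: {t_reshape:.4f}s  (10k calls each)")
```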
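einsum and matmul
See the official docs for usage details. For complex matrix multiplications on high-dimensional tensors, einsum is often the go-to choice because its index notation is clear and concise. But when the same computation can also be written with matmul, which one is faster?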
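The sketch below shows how the two relate in this benchmark's setting, using small made-up sizes and t = 1 as in the test code. The einsum string 'btnf,kfc->btknc' contracts over f. matmul produces the same numbers through batch broadcasting, just without the singleton t axis.

```python
import torch

# Toy stand-ins for the benchmark shapes: data1 is (b, t, n, f) with t = 1,
# data2 is (k, f, c).
b, t, n, f, k, c = 4, 1, 5, 3, 6, 2
data1 = torch.rand(b, t, n, f)
data2 = torch.rand(k, f, c)

# einsum: sum over f, keep every other index, output ordered b, t, k, n, c.
out_einsum = torch.einsum('btnf,kfc->btknc', data1, data2)

# matmul multiplies the last two dims, (n x f) @ (f x c), and broadcasts the
# leading batch dims (b, 1) against (k,) to (b, k).
out_matmul = torch.matmul(data1, data2)

print(out_einsum.shape)  # torch.Size([4, 1, 6, 5, 2])
print(out_matmul.shape)  # torch.Size([4, 6, 5, 2])

# With t = 1 the results agree once the singleton t axis is dropped.
print(torch.allclose(out_einsum.squeeze(1), out_matmul))  # True
```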
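Conclusion: in this test, where matmul relies on broadcasting, einsum was faster. Simple cases were not tested, but for those, plain matmul is more convenient to write.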
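CUDA kernels launch asynchronously, so a timing comparison like this is only meaningful if the timer waits for the GPU to finish. The helper below is a hypothetical sketch; time_gpu_safe is not a PyTorch API. It warms up once, synchronizes before and after the timed loop, and falls back to CPU when no GPU is present. Which op wins may depend on the hardware, the PyTorch version and the shapes, so treat its output as indicative only.

```python
import time

import torch


def time_gpu_safe(fn, repeat=10):
    """Hypothetical helper: average wall-clock seconds per call of fn()."""
    cuda = torch.cuda.is_available()
    fn()  # warm-up call, excluded from the measurement
    if cuda:
        torch.cuda.synchronize()  # make sure the warm-up has finished
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    if cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to complete
    return (time.perf_counter() - start) / repeat


device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Benchmark-like shapes, with b shrunk from 128 to 16 to keep the sketch light.
data1 = torch.rand(16, 1, 300, 30, device=device)
data2 = torch.rand(20, 30, 30, device=device)

t_einsum = time_gpu_safe(lambda: torch.einsum('btnf,kfc->btknc', data1, data2))
t_matmul = time_gpu_safe(lambda: torch.matmul(data1, data2))
print(f"device={device}  einsum: {t_einsum * 1e3:.2f} ms  matmul: {t_matmul * 1e3:.2f} ms")
```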
Implementing the kind of complex contraction einsum handles with a Python for loop is much slower. Whatever you do, don't use a for loop!
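For t > 1, plain matmul(data1, data2) no longer matches the einsum. It raises a broadcasting error, or, if t happens to equal k, it silently pairs them up instead of forming every combination. The commented-out loop in bench_matmul works around this. The sketch below uses toy sizes to show an alternative: inserting a singleton axis with unsqueeze(2) lets a single matmul call produce the same (b, t, k, n, c) result as einsum, with no Python loop. Whether this beats einsum on speed was not measured here.

```python
import torch

b, t, n, f, k, c = 2, 3, 4, 5, 6, 7
data1 = torch.rand(b, t, n, f)
data2 = torch.rand(k, f, c)

ref = torch.einsum('btnf,kfc->btknc', data1, data2)

# Slow: Python loop over t, mirroring the commented-out code in bench_matmul.
loop = torch.stack(
    [torch.matmul(data1[:, j].unsqueeze(1), data2) for j in range(t)], dim=1
)

# Vectorized: batch dims (b, t, 1) broadcast against (k,) to (b, t, k),
# so one matmul call yields (b, t, k, n, c).
vec = torch.matmul(data1.unsqueeze(2), data2)

print(ref.shape, loop.shape, vec.shape)  # each is torch.Size([2, 3, 6, 4, 7])
print(torch.allclose(ref, loop), torch.allclose(ref, vec))  # True True
```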
Code
import time
from collections import defaultdict

import torch

# Run on the GPU when one is available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def generate_tensor(shape, num=1):
    """Return a list of `num` random tensors with the given shape."""
    return [torch.rand(shape, device=DEVICE) for _ in range(num)]


def bench_view(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.view(shape_to)  # never copies; shares storage with d
            res.append(tmp)
        # torch.cat copies every tensor, so it accounts for most of the time here
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))


def bench_reshape(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.reshape(shape_to)  # returns a view when d is contiguous
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))


def bench_einsum(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            # (b, t, n, f) x (k, f, c) -> (b, t, k, n, c), summing over f
            tmp = torch.einsum('btnf,kfc->btknc', data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))


def bench_matmul(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            # NOTE: looping over dim 1 like this is far too slow:
            # tmp = []
            # for j in range(data1[i].shape[1]):
            #     left = data1[i][:, j, ...].unsqueeze(dim=1)
            #     tmp.append(torch.matmul(left, data2[i]))
            # tmp = torch.stack(tmp, dim=1)
            # Batch dims (b, 1) broadcast against (k,) -> output (b, k, n, c)
            tmp = torch.matmul(data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out), out[0])


class StopWatch:
    """Prints the time elapsed since the previous tk() call with the same key."""

    def __init__(self) -> None:
        self.times = defaultdict(list)

    def tk(self, key='default'):
        # CUDA ops are asynchronous: wait for queued kernels before reading the clock
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        tkn = time.time()
        if key in self.times:
            print(f'key={key}, elapsed:', tkn - self.times[key][-1])
        self.times[key].append(tkn)


def test_view_and_reshape():
    # 128*300*30*10 == 100*64*60*30 == 11,520,000 elements
    datas = generate_tensor(shape=(128, 300, 30, 10), num=10)
    sw = StopWatch()
    sw.tk()
    bench_view(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    # bench_reshape(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    sw.tk()


def test_einsum_matmul():
    data1 = generate_tensor(shape=(128, 1, 300, 30), num=5)  # (b, t=1, n, f)
    data2 = generate_tensor(shape=(20, 30, 30), num=5)       # (k, f, c)
    sw = StopWatch()
    sw.tk()
    bench_einsum(data1, data2, repeat=1000)
    # bench_matmul(data1, data2, repeat=1000)
    sw.tk()


if __name__ == '__main__':
    # test_view_and_reshape()
    test_einsum_matmul()
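This content was contributed by users or collected from the web and is provided for learning reference; copyright belongs to the original author.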