我是靠谱客的博主 任性吐司,最近开发中收集的这篇文章主要介绍pytorch reshape view性能对比 (以及einsum, matmul)reshape和vieweinsum和matmul代码,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

reshape和view

reshape 和 view的具体功能和区别就不介绍了,可以直接查看官网,简单来说,就是reshape会改变各个维度在存储中的物理位置,而view的话,只会改变索引。

那么我们在使用的时候,到底是选哪个呢?

先说结论:差不多~

einsum和matmul

具体用法参见官网,当我们想在维度比较高的tensor上做复杂的矩阵乘法的时候,往往会选择用einsum,因为比较清晰简单,但是如果有办法用matmul的时候,是不是会犹豫两个的性能呢?

结论:在有matmul有broadcast的情况下,einsum更快。简单情况没有测试,但是简单情况直接用matmul比较方便。

如果用for循环来实现einsum可以实现的复杂功能的话,会慢很多,所以千万不要用for loop!!!!

代码

import torch
import time
from collections import defaultdict


def generate_tensor(shape, num=1):
    return [torch.rand(shape).cuda() for _ in range(num)]


def bench_view(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.view(shape_to)
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))
    
    
def bench_reshape(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.reshape(shape_to)
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))
    
def bench_einsum(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            tmp = torch.einsum('btnf, kfc -> btknc', data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))


def bench_matmul(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            # NOTE: if use for loop the speed is too low.
            # tmp = []
            # for j in range(data1[i].shape[1]):
            #     left = data1[i][:,j,...].unsqueeze(dim=1)
            #     tmp.append(torch.matmul(left, data2[i]))
            # tmp = torch.stack(tmp, dim=1)
            
            tmp = torch.matmul(data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out), out[0])
    
    

class StopWatch:
    def __init__(self) -> None:
        self.times = defaultdict(list)
    
    def tk(self, key='default'):
        tkn = time.time()
        if key in self.times:
            print(f'key={key}, elapsed:', tkn - self.times[key][-1])
        
        self.times[key].append(tkn)


def test_view_and_reshape():
    
    datas = generate_tensor(shape=(128, 300,30,10), num=10)
    sw = StopWatch()
    
    sw.tk()
    bench_view(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    # bench_reshape(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    sw.tk()
    
    
def test_einsum_matmul():
    data1 = generate_tensor(shape=(128, 1, 300,30), num=5)
    data2 = generate_tensor(shape=(20, 30, 30), num=5)
    sw = StopWatch()
    
    sw.tk()
    bench_einsum(data1, data2, repeat=1000)
    # bench_matmul(data1, data2, repeat=1000)
    sw.tk()
    
if __name__ == '__main__':
    # test_view_and_reshape()
    test_einsum_matmul()
    
    

最后

以上就是任性吐司为你收集整理的pytorch reshape view性能对比 (以及einsum, matmul)reshape和vieweinsum和matmul代码的全部内容,希望文章能够帮你解决pytorch reshape view性能对比 (以及einsum, matmul)reshape和vieweinsum和matmul代码所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部