标致世界

文章
4
资源
0
加入时间
4年1月25天

Pytorch的torch.max()使用讲解

output = torch.max(input, dim)input输入的是一个tensordim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值实例import torchimport numpy as npimport matplotlib.pyplot as pltx = torch.randn(3,3)print(x)max_value,index = torch.max(x,dim=1) #返回的是两个值,一个是每一行最大值的tensor组,另一个是