Pytorch的torch.max()使用讲解
output = torch.max(input, dim)input输入的是一个tensordim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值实例import torchimport numpy as npimport matplotlib.pyplot as pltx = torch.randn(3,3)print(x)max_value,index = torch.max(x,dim=1) #返回的是两个值,一个是每一行最大值的tensor组,另一个是