Pytorch学习 (二十六)---- torch.scatter的使用总说
总说一个非常有用的函数,主要是用于“group index”的操作。from torch_scatter import scatterimport torchsrc = (torch.rand(2, 6, 2)*4).int()index = torch.tensor([0, 1, 0, 1, 2, 1])# Broadcasting in the first and last dim.out = scatter(src, index, dim=1, reduce="sum"