The PyTorch flatten() function
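Talk is cheap, show me the code.

import torch

# A 3 x 2 x 2 tensor holding the integers 1 to 12
t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])

# With no extra arguments, flatten() collapses every dimension into one
print(torch.flatten(t))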
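
The snippet above only shows the default call. As a rough sketch of what else `torch.flatten` can do, the example below reuses the same tensor `t` and passes the optional `start_dim` and `end_dim` arguments, which limit the flattening to a range of dimensions. The variable names (`flat_all`, `flat_from_1`, `flat_first_two`) are made up for illustration and are not part of the original post.

```python
import torch

# Same 3 x 2 x 2 tensor as in the snippet above
t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])

# Default: start_dim=0, end_dim=-1, so all dimensions collapse into one
flat_all = torch.flatten(t)
print(flat_all.shape)        # torch.Size([12])

# Keep dimension 0 and flatten the rest, a common step before a linear layer
flat_from_1 = torch.flatten(t, start_dim=1)
print(flat_from_1.shape)     # torch.Size([3, 4])

# Merge only dimensions 0 and 1, leaving the last dimension untouched
flat_first_two = torch.flatten(t, start_dim=0, end_dim=1)
print(flat_first_two.shape)  # torch.Size([6, 2])

# The tensor method gives the same result as the function form
print(torch.equal(t.flatten(), flat_all))  # True
```

In every case the elements keep their row-major order, so `flat_all` runs from 1 to 12 and each row of `flat_from_1` holds one of the original 2 x 2 blocks.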