1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
|
import torch
def gather_word_embeddings(token_embeddings, word_index):
    """Return ``token_embeddings[b, word_index[b, i], :]`` for every ``(b, i)``.

    Args:
        token_embeddings: Tensor of shape ``(batch, seq_len, dim)``.
        word_index: Integer tensor of shape ``(batch, num_words)``; each value
            is a position along the ``seq_len`` axis of ``token_embeddings``.

    Returns:
        Tensor of shape ``(batch, num_words, dim)``.
    """
    # torch.gather needs an index with the same rank as the source, so each
    # position index is repeated across the embedding (last) dimension.
    index = word_index.unsqueeze(-1).expand(-1, -1, token_embeddings.size(-1))
    return torch.gather(token_embeddings, dim=1, index=index)


# Embedding table: 9 ids, each mapped to a randomly initialised 10-dim vector.
model = torch.nn.Embedding(9, 10)

# Token ids per batch row (despite the name, these are ids, not embeddings).
# The trailing 0s look like padding -- note that id 0 still maps to a real,
# non-zero vector because no padding_idx is set on the Embedding.
input_embed = torch.tensor([[2, 3, 4, 5, 0], [6, 7, 8, 0, 0]])

# Sequence positions to pick out for each row. Trailing 0s select position 0
# again, so those slots repeat the first token's vector (see output below).
word_index = torch.tensor([[0, 1, 2, 3, 0], [0, 1, 2, 0, 0]])

out = model(input_embed)  # shape (2, 5, 10)
print(out, out.shape)

result = gather_word_embeddings(out, word_index)  # shape (2, 5, 10)
print(result, result.shape)

# Sample output from one run (weights are random, so the numbers vary):
#
# tensor([[[-0.1290, -0.8854, -0.0281, -1.0047, -0.6565,  0.2380,  1.6062, -1.0218,  0.9187, -0.5830],
#          [ 1.7823,  0.1437,  0.8367,  0.3261,  0.0991, -0.8338,  1.5731,  2.6733,  0.2048, -0.4198],
#          [ 0.9965,  2.6325,  1.1463, -0.3047,  0.7547, -1.9135, -1.9450,  0.1363,  1.5608,  1.0028],
#          [-0.3929,  0.3888,  0.3454, -0.5054, -0.0680, -0.3803,  1.2884, -1.1461, -0.3259,  0.6795],
#          [ 0.1853, -0.8628, -1.4179,  1.4490,  0.6766,  0.7106, -0.6956, -0.2125,  0.2669, -0.0373]],
#
#         [[-0.9646,  0.6840,  1.7190,  1.2912, -1.1019,  0.6682, -0.0500, -0.2393,  0.8611,  1.2914],
#          [ 0.6630,  0.7863, -0.2253, -1.5720, -0.4309, -2.0466, -1.0762,  0.5243, -0.3297, -0.0142],
#          [ 0.0903, -1.0030,  0.1973,  0.9981,  1.2901, -0.5555, -0.2912, -0.6930, -0.1299, -0.9054],
#          [ 0.1853, -0.8628, -1.4179,  1.4490,  0.6766,  0.7106, -0.6956, -0.2125,  0.2669, -0.0373],
#          [ 0.1853, -0.8628, -1.4179,  1.4490,  0.6766,  0.7106, -0.6956, -0.2125,  0.2669, -0.0373]]],
#        grad_fn=<EmbeddingBackward>) torch.Size([2, 5, 10])
#
# tensor([[[-0.1290, -0.8854, -0.0281, -1.0047, -0.6565,  0.2380,  1.6062, -1.0218,  0.9187, -0.5830],
#          [ 1.7823,  0.1437,  0.8367,  0.3261,  0.0991, -0.8338,  1.5731,  2.6733,  0.2048, -0.4198],
#          [ 0.9965,  2.6325,  1.1463, -0.3047,  0.7547, -1.9135, -1.9450,  0.1363,  1.5608,  1.0028],
#          [-0.3929,  0.3888,  0.3454, -0.5054, -0.0680, -0.3803,  1.2884, -1.1461, -0.3259,  0.6795],
#          [-0.1290, -0.8854, -0.0281, -1.0047, -0.6565,  0.2380,  1.6062, -1.0218,  0.9187, -0.5830]],
#
#         [[-0.9646,  0.6840,  1.7190,  1.2912, -1.1019,  0.6682, -0.0500, -0.2393,  0.8611,  1.2914],
#          [ 0.6630,  0.7863, -0.2253, -1.5720, -0.4309, -2.0466, -1.0762,  0.5243, -0.3297, -0.0142],
#          [ 0.0903, -1.0030,  0.1973,  0.9981,  1.2901, -0.5555, -0.2912, -0.6930, -0.1299, -0.9054],
#          [-0.9646,  0.6840,  1.7190,  1.2912, -1.1019,  0.6682, -0.0500, -0.2393,  0.8611,  1.2914],
#          [-0.9646,  0.6840,  1.7190,  1.2912, -1.1019,  0.6682, -0.0500, -0.2393,  0.8611,  1.2914]]],
#        grad_fn=<GatherBackward>) torch.Size([2, 5, 10])
|
发表评论 取消回复