今天在看代码的时候,发现对于嵌套的代码理解不太深刻,这里举个例子加深下印象
先看个简单的例子
a = torch.tensor([1,2,3])
b = torch.tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0],[1,1,1,1,1,1,1,1,1,1,1,1,1]])
我们再来看个复杂的例子。
a
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]])
b
tensor([[0, 0],
[1, 1]])
c
tensor([[3, 3],
[2, 2]])
a[b,:,:]
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]]],
[[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]]])
a[b,c,:]
tensor([[[12, 13, 14, 15],
[12, 13, 14, 15]],
[[24, 25, 26, 27],
[24, 25, 26, 27]]])
|