1. 官网解释
o
u
t
[
i
]
[
j
]
[
k
]
=
i
n
p
u
t
[
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
[
j
]
[
k
]
,
?
i
f
?
d
i
m
=
=
0
out[i][j][k] = input[[index[i][j][k]][j][k],\ if \ dim==0
out[i][j][k]=input[[index[i][j][k]][j][k],?if?dim==0 以此类推
意思就是说,假设我要求
o
u
t
[
i
]
[
j
]
out[i][j]
out[i][j]位置的值,如果dim=0,那么
i
i
i就会被
i
n
d
e
x
[
i
]
[
j
]
index[i][j]
index[i][j]代替。而如果dim=1,那么
j
j
j就会被
i
n
d
e
x
[
i
]
[
j
]
index[i][j]
index[i][j]代替。其他类推。
2. 一个例子
假如我有一个这样的tensor
t = torch.tensor([[1,2],[3,4]])
out = torch.gather(t,0,index=torch.tensor([[0,0],[1,0]]))
print(out)
此时dim = 0,所以
o
u
t
[
0
]
[
0
]
=
i
n
p
u
t
[
i
n
d
e
x
[
0
]
[
0
]
]
[
0
]
=
i
n
p
u
t
[
0
]
[
0
]
=
1
out[0][0] = input[index[0][0]][0]=input[0][0]=1
out[0][0]=input[index[0][0]][0]=input[0][0]=1
o
u
t
[
0
]
[
1
]
=
i
n
p
u
t
[
i
n
d
e
x
[
0
]
[
1
]
]
[
1
]
=
i
n
p
u
t
[
0
]
[
1
]
=
2
out[0][1] = input[index[0][1]][1]=input[0][1]=2
out[0][1]=input[index[0][1]][1]=input[0][1]=2
o
u
t
[
1
]
[
0
]
=
i
n
p
u
t
[
i
n
d
e
x
[
1
]
[
0
]
]
[
0
]
=
i
n
p
u
t
[
1
]
[
0
]
=
3
out[1][0] = input[index[1][0]][0]=input[1][0]=3
out[1][0]=input[index[1][0]][0]=input[1][0]=3
o
u
t
[
1
]
[
1
]
=
i
n
p
u
t
[
i
n
d
e
x
[
1
]
[
1
]
]
[
1
]
=
i
n
p
u
t
[
0
]
[
1
]
=
2
out[1][1] = input[index[1][1]][1]=input[0][1]=2
out[1][1]=input[index[1][1]][1]=input[0][1]=2 所以结果应该是
[
[
1
,
2
]
,
[
3
,
2
]
]
[[1,2],[3,2]]
[[1,2],[3,2]]
3. 代码验证
4. 我悟了
其实可以这么看,把值这么一摆:
|