tensor.gather(dim,index)
官方讲解:https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
知乎讲解:https://zhuanlan.zhihu.com/p/352877584
理解需要15min
1.例子
tensor_0 = [[3,4,5],
[6,7,8],
[9,10,11]]
index = [[2,1,0]]
tensor_1 = tensor_0.gather(dim=0,index=index)
tensor_1 = [[9,7,5]]
tensor_1 = tensor_0.gather(dim=1,index=index)
tensor_1 = [[5,4,3]]
2.个人理解
tensor_1 = tensor_0.gather(dim,index)
首先, 这句话的含义是,利用index的索引,选择tensor_0中的某些元素组成tensor_1。
这里的tensor_1的shape与index保持一致,与tensor_0没关系。
其次,dim=0表示:
tensor_1[0,0] = tensor_0[index[0,0],0]
tensor_1[0,1] = tensor_0[index[0,1],1]
...
即,tensor_1的第0列一定是从tensor_0的第0列里取值,取哪行看index;第二列同理,也是从第二列取值,取哪个看index。所以给tensor_1赋值的时候,按照列的顺序,把tensor_0第0列的某个值给tensor_1的第0列,把tensor_0的第一列的某个值给tensor_1的第1列…即当dim=0的时候,index按照列索引赋值
PS:看行和列哪一个维度没有index,就是按照哪一维索引。dim=0时列维数是0 1 2(确定的),所以是按列索引。
同理,dim=1表示:
tensor_1[0,0] = tensor_0[0,index[0,0]]
tensor_1[0,1] = tensor_0[0,index[0,1]]
...
即,tensor_1的第0行一定是从tensor_0的第0行里取值,取哪列看index;第二行同理,也是从第二行取值,取哪个看index。所以给tensor_1赋值的时候,按照行的顺序,把tensor_0第0行的某个值给tensor_1的第0行,把tensor_0的第一行的某个值给tensor_1的第1行…即当dim=1的时候,index按照行索引赋值
3.练习
练习1
tensor_0 = [[3,4,5],
[6,7,8],
[9,10,11]]
index = [[2],
[1],
[0]]
tensor_1=tensor_0.gather(dim=1,index=index)
output:
tensor_1:[[5],
[7],
[9]]
练习2
tensor_0 = [[3,4,5],
[6,7,8],
[9,10,11]]
index = [[0,2],
[1,2]]
tensor_1=tensor_0.gather(dim=1,index=index)
output:
tensor_1:[[3,5],
[7,8]]
4.结论
tensor_1 = tensor_0.gather(dim,index)
1.tensor_1的维度与index一致,与tensor_0无关。
2.当dim=0时,index按照列索引赋值。
3.当dim=1时,index按照行索引赋值。
4.看行和列哪一个维度没有index,就是按照哪一维索引。dim=0时列维数是0 1 2(确定的),所以是按列索引。
|