官方介绍如下:链接?功能就是从原tensor中获取指定dim和指定index的数据。
直接例子说明:
import torch
a = torch.arange(16, 32).view(4, 4)
print(a)
输出:
tensor([[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]])
1. 按列索引,index为行向量
index = torch.tensor([[2, 1, 0, 3]])
b= a.gather(0, index)
print(b)
输出:
tensor([[24, 21, 18, 31]])
dim=0代表按列索引,那么index第一个元素“2”的含义为在a中其所在列(即第0列)的第2个元素。同理,index第二个元素“1”的含义为在a中其所在列(即第1列)的第1个元素;index第三个元素“0”的含义为在a中其所在列(即第3列)的第0个元素;index第三个元素“3”的含义为在a中其所在列(即第4列)的第3个元素;
2. 按行索引,index为行向量
index = torch.tensor([[2, 1, 0, 3]])
b= a.gather(1, index)
print(b)
输出:
tensor([[18, 17, 16, 19]])
dim=1代表按行索引,而[[2, 1, 0, 3]]本身就是行向量,故“2”“1”“0”“3”都代表的是a中第0行的对应列数的元素,将它们拿出来,即组成b
3. 按列索引,index为列向量
index = torch.tensor([[2, 1, 0, 3]])
b= a.gather(1, index.t()) #index进行转置
print(b)
输出:
tensor([[18],
[21],
[24],
[31]])
dim=0代表按行索引,那么index第一个元素“2”的含义为在a中其所在行(即第0行)的第2个元素。同理,index第二个元素“1”的含义为在a中其所在行(即第1行)的第1个元素;index第三个元素“0”的含义为在a中其所在行(即第3行)的第0个元素;index第三个元素“3”的含义为在a中其所在行(即第4行)的第3个元素
结语:以b?= a.gather(dim=0, index)为例,首先确定b的维度与index维度一致(index维度可以是任意的维度,不要受限于a),即b的维度为(1,4)
|