torch.gather
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
沿指定的维收集值。
参数:
-
input (Tensor) –输入张量 -
dim (int) – 要索引的维 -
index (LongTensor) – 要收集的元素的索引 -
sparse_grad (bool, optional) – 如果为True ,关于input 的梯度将是稀疏张量。 -
out (Tensor, optional) –输出张量
对于一维张量,输出由以下公式指定:
out[i] = input[index[i]] # dim= 0
例如:
input_tensor= torch.tensor([1, 2])
index = torch.tensor([0, 0])
input[0]=1
input[1]=2
index[0]=0
index[0]=0
out = torch.gather(input, 0, index)
out[0]=input[index[0]]=input[0]=1
out[1]=input[index[1]]=input[0]=1
对于二维张量,输出由以下公式指定:
out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1
举个栗子:
input_tensor= torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
input[0][0]=1
input[0][1]=2
input[1][0]=3
input[1][1]=4
index[0][0]=0
index[0][1]=0
index[1][0]=1
index[1][1]=0
dim=0:
out = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[index[0][0]][0]=input[0][0]=1
out[0][1]=input[index[0][1]][1]=input[0][1]=2
out[1][0]=input[index[1][0]][0]=input[1][0]=3
out[1][1]=input[index[1][1]][1]=input[0][1]=2
dim=1:
out = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[0][index[0][0]]=input[0][0]=1
out[0][1]=input[0][index[0][1]]=input[0][0]=1
out[1][0]=input[1][index[1][0]]=input[1][1]=4
out[1][1]=input[1][index[1][1]]=input[1][0]=3
对于三维张量,同理:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
注意:input 和index 必须有相同的维度。out 尺寸和index 相同;input 和index 之间不会广播。
对于d=dim ,可以有index.size(d)< input.size(d) :
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0],[2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 2])
tensor([[2, 1],
[6, 4]])
或index.size(d)> input.size(d) :
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0, 1, 0], [2, 0, 2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 4])
tensor([[2, 1, 2, 1],
[6, 4, 6, 4]])
|