好久没更新博客了,最近一直在忙,既有生活上的也有工作上的。道阻且长啊。
今天来水一文,说一说最近工作上遇到的一个函数:torch.gather() 。
文字理解
我遇到的代码是 NLP 相关的,代码中用 torch.gather() 来将一个 tensor 的 shape 从 (batch_size, seq_length, hidden_size) 转为 (batch_size, labels_length, hidden_size) ,其中 seq_length >= labels_length 。
torch.gather() 的官方解释是
Gathers values along an axis specified by dim.
就是在指定维度上 gather value。那么怎么 gather、gather 哪些 value 呢?这就要看其参数了。
torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:
input (Tensor) – the source tensordim (int) – the axis along which to indexindex (LongTensor) – the indices of elements to gather
所以一句话概括 gather 操作就是:根据 index ,在 input 的 dim 维度上收集 value。
具体来说,input 就是源 tensor,等会我们要在这个 tensor 上执行 gather 操作。如果 input 是一个一维数组,即 flat 列表,那么我们就可以直接根据 index 在 input 上取了,就像正常的列表/数组索引一样。但是由于 input 可能含有多个维度,是 N 维数组,所以我们需要知道在哪个维度上进行 gather,这就是 dim 的作用。
对于 dim 参数,一种更为具体的理解方式是替换法。假设 input 和 index 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可,但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报 IndexError 。所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。如果 dim=0 ,我们就替换索引列表第 0 个值,即 [dim, j, k] ,依此类推。Pytorch 的官方文档的写法其实也是这个意思,但是看这么多个方括号可能会有点懵:
out[i][j][k] = input[index[i][j][k]][j][k]
out[i][j][k] = input[i][index[i][j][k]][k]
out[i][j][k] = input[i][j][index[i][j][k]]
但是可能你还有点迷糊,没关系接着看下面的直观理解部分,然后再回来看这段话,结合着看,相信你很快能明白。
由于我们是按照 index 来取值的,所以最终得到的 tensor 的 shape 也是和 index 一样的,就像我们在列表上按索引取值,得到的输出列表长度和索引相等一样。
直观理解
为便于理解,我们以一个具体例子来说明。我们使用反推法,根据 input 和输出推参数。这应该也是我们平常自己写代码的时候遇到比较多的情况。
假设 input 和我们想要的输出 output 如下:
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> output_tensor
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]])
即,我们想让 shape 为 (2, 3, 4) 的 input_tensor 变成 shape 为 (2, 2, 4) 的 output_tensor ,丢弃维度 1 的第 2 个元素,即 [ 4, 5, 6, 7] 和 [16, 17, 18, 19] 。
我们应用替换法,重点是找出来 dim 和 index 的值。始终记住 index 和 output_tensor 的 shape 是一样的。
从 output_tensor 的第一个位置开始,由于 output_tensor[0, 0, :] = input_tensor[0, 0, :] ,所以此时 [i, j, k] 是一样的,我们看不出来 dim 应该是多少。
下一行 output_tensor[0, 1, 0] = input_tensor[0, 2, 0] ,这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2, index_tensor[0, 1, 0]=2 。
此时 dim 已经明确。同理,output_tensor[0, 1, 1] = input_tensor[0, 2, 1] ,index_tensor[0, 1, 1]=2 ,依此类推,得到 index_tensor[0, 1, :] = 2 。同时也可以明确 index_tensor[0, 0, :] = 0 。
所以
>>> dim = 0
>>> index_tensor
tensor([[[0, 0, 0, 0],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[2, 2, 2, 2]]])
简单可描述如下图: 为描述方便,假如我们把输入看作是 6 行,从上到下依次是 0-5。那么从事后诸葛亮的角度讲,输出相当于是把第 1 和第 4 行“抽掉”。如果输出和输入一样,那么原本的 index_tensor 就是如下:
tensor([[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]]])
“抽掉”后, index_tensor 也相应“抽掉”,那么就得到我们想要的结果了。而且由于这个“抽掉”的操作是在维度 1 上进行的,那么 dim 自然是 1。
numpy.take() 和 tf.gather 貌似也是同样功能,就不细说了。
Reference
END
|