IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 理解 PyTorch 中的 gather 函数 -> 正文阅读

[人工智能]理解 PyTorch 中的 gather 函数

好久没更新博客了,最近一直在忙,既有生活上的也有工作上的。道阻且长啊。

今天来水一文,说一说最近工作上遇到的一个函数:torch.gather()

文字理解

我遇到的代码是 NLP 相关的,代码中用 torch.gather() 来将一个 tensor 的 shape 从 (batch_size, seq_length, hidden_size) 转为 (batch_size, labels_length, hidden_size) ,其中 seq_length >= labels_length

torch.gather() 的官方解释是

Gathers values along an axis specified by dim.

就是在指定维度上 gather value。那么怎么 gather、gather 哪些 value 呢?这就要看其参数了。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

所以一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

具体来说,input 就是源 tensor,等会我们要在这个 tensor 上执行 gather 操作。如果 input 是一个一维数组,即 flat 列表,那么我们就可以直接根据 indexinput 上取了,就像正常的列表/数组索引一样。但是由于 input 可能含有多个维度,是 N 维数组,所以我们需要知道在哪个维度上进行 gather,这就是 dim 的作用。

对于 dim 参数,一种更为具体的理解方式是替换法。假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可,但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报 IndexError 。所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。如果 dim=0 ,我们就替换索引列表第 0 个值,即 [dim, j, k] ,依此类推。Pytorch 的官方文档的写法其实也是这个意思,但是看这么多个方括号可能会有点懵:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

但是可能你还有点迷糊,没关系接着看下面的直观理解部分,然后再回来看这段话,结合着看,相信你很快能明白。

由于我们是按照 index 来取值的,所以最终得到的 tensor 的 shape 也是和 index 一样的,就像我们在列表上按索引取值,得到的输出列表长度和索引相等一样。

直观理解

为便于理解,我们以一个具体例子来说明。我们使用反推法,根据 input 和输出推参数。这应该也是我们平常自己写代码的时候遇到比较多的情况。

假设 input 和我们想要的输出 output 如下:

>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
>>> output_tensor  # shape: (2, 2, 4)
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])

即,我们想让 shape 为 (2, 3, 4)input_tensor 变成 shape 为 (2, 2, 4)output_tensor ,丢弃维度 1 的第 2 个元素,即 [ 4, 5, 6, 7][16, 17, 18, 19]

我们应用替换法,重点是找出来 dimindex 的值。始终记住 indexoutput_tensor 的 shape 是一样的。

output_tensor 的第一个位置开始,由于 output_tensor[0, 0, :] = input_tensor[0, 0, :] ,所以此时 [i, j, k] 是一样的,我们看不出来 dim 应该是多少。

下一行 output_tensor[0, 1, 0] = input_tensor[0, 2, 0] ,这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2, index_tensor[0, 1, 0]=2

此时 dim 已经明确。同理,output_tensor[0, 1, 1] = input_tensor[0, 2, 1]index_tensor[0, 1, 1]=2 ,依此类推,得到 index_tensor[0, 1, :] = 2 。同时也可以明确 index_tensor[0, 0, :] = 0

所以

>>> dim = 0
>>> index_tensor
tensor([[[0, 0, 0, 0],
         [2, 2, 2, 2]],

        [[0, 0, 0, 0],
         [2, 2, 2, 2]]])

简单可描述如下图:
torch.gather() 执行过程
为描述方便,假如我们把输入看作是 6 行,从上到下依次是 0-5。那么从事后诸葛亮的角度讲,输出相当于是把第 1 和第 4 行“抽掉”。如果输出和输入一样,那么原本的 index_tensor 就是如下:

tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]]])

“抽掉”后, index_tensor 也相应“抽掉”,那么就得到我们想要的结果了。而且由于这个“抽掉”的操作是在维度 1 上进行的,那么 dim 自然是 1。

numpy.take()tf.gather 貌似也是同样功能,就不细说了。

Reference

END

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-30 12:03:49  更:2021-08-30 12:06:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 21:39:53-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码