IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【论文学习】Transformer-XL -> 正文阅读

[人工智能]【论文学习】Transformer-XL

简介

RNN及其变体是训练语言模型(Language Modeling)的经典结构,其优点就是能够学习到序列之间的依赖关系,缺点:1)随着序列长度的增加,序列之间的依赖关系信息会逐渐丢失;2)单向;3)计算速度慢,只能step by step。截止到Transformer-XL,单向学习似乎是LM任务绕不过去的坎,XLNet使用PLM才比较隐蔽的解决单向学习语言模型的问题。

为了解决RNN存在的问题,Transformer-XL(XL表示extra long)沿用Transformer 中的“注意力机制”,但只是用Decoder部分学习语言模型。为了使用注意力机制学习语言模型, 1)继续沿用Mask Attention方法,使得预测当前词时,只能看到之前词的信息;2)引入“循环机制”、“相对位置编码”,提高了处理长文本的能力。第二点是本文的主要创新点。
?

模型

Vanilla Transformer

Transformer-XL引入的循环机制,参考了Vanilla Transformer的思想,因此先介绍下Vanilla Transformer模型。Vanilla对于长输入,在训练阶段,会将输入割裂成独立的几个部分(segment),然后分别处理。在推断阶段,每次取segment长度的输入进行处理,输出一个预测词,然后向右移动一个位置。如下图所示:
Vanila
这样做有明显的缺点:1)编码或者推断过程,最多只能看到segment length的信息,当原始输入中存在长距离依赖关系时,会学习不足;2)割裂处理每一部分,产生了碎片化问题;3)在推断阶段,每次都从头计算,计算效率非常低。
?

recurrence mechanism

Transformer-XL中的循环机制的粒度是“segment”(后面称为切片),不像是RNN中以字为粒度的。另外一个显著不同点是,计算第 i + 1 i+1 i+1切片的第 N N N层隐向量时,需要用到第 i i i 个切片的第 N ? 1 N-1 N?1层隐向量信息、以及第 i + 1 i+1 i+1切片的第 N ? 1 N-1 N?1层隐向量。注意是“前一个切片的下一层的隐向量”,如下所示:
transformer-XL
即对于第 i + 1 i+1 i+1切片的最后一层隐向量的最后一个位置,可以看到 N ? L N* L N?L范围的信息, L L L表示切片长度。 引入循环机制,使得信息可以在不同的segment之间流动,避免了碎片化。
?

相对位置编码

标准Transformer采用的是sin-cos相对位置编码(多头注意力机制是没有考虑输入序列的位置信息的,因此必须额外的引入位置信息,否则就类似词袋模型了),但由于没有引用“循环机制”,因此可以不采用相对位置的编码方式。而Transformer-XL由于引入了循环机制,因此必须考虑切片之间的相对位置信息,否则每一切片的同一位置信息是一致的,这显然不合理。假设某一切片内存在 i , j i, j i,j两个位置,则这两个位置的:
a i , j a b s = q i T ? k j = ( W q ? ( E ( x i ) + U i ) T ? ( W k ? ( E ( x j ) + U j ) ) = E x i T ? W q T ? W k ? E x j + E x i T ? W q T ? W k ? U j + U i T ? W q T ? W k ? E x j + U i T ? W q T ? W k ? U j a_{i,j}^{abs} = q_i^T * k_j = (W_q * (E(x_i) + U_i)^T * (W_k * (E(x_j) + U_j)) \newline = E_{x_i}^T * W_q^T * W_k * E_{x_j} \newline + E_{x_i}^T * W_q^T * W_k * U_j \newline + U_i^T * W_q^T * W_k * E_{x_j} \newline + U_i^T * W_q^T * W_k * U_j ai,jabs?=qiT??kj?=(Wq??(E(xi?)+Ui?)T?(Wk??(E(xj?)+Uj?))=Exi?T??WqT??Wk??Exj??+Exi?T??WqT??Wk??Uj?+UiT??WqT??Wk??Exj??+UiT??WqT??Wk??Uj?

其中 U U U表示PositionalEmbedding矩阵,该矩阵就是Transformer中使用的PE,是不需要学习的。论文对上面的四个子项进行优化,优化后如下
在这里插入图片描述
四项的含义可以参看原文,其中 u T u^T uT v T v^T vT W k , R W_{k, R} Wk,R?是学习参数, R R R矩阵是PE矩阵,不需要学习。对上式重写:
A i , j r e l = q i T k j + q i T W k , R R i ? j + u T k j + v T W k , R R i ? j A_{i,j}^{rel} = q_i^Tk_j + q_i^TW_{k,R}R_{i-j} + u^Tk_j + v^T W_{k, R}R_{i-j} Ai,jrel?=qiT?kj?+qiT?Wk,R?Ri?j?+uTkj?+vTWk,R?Ri?j?

再将(a)、(c)项合并, (b)、(d)项合并,如下:
A i , j r e l = ( q i T + u T ) k j + ( q i T + v T ) W k , R R i ? j A_{i,j}^{rel} = (q_i^T+u^T)k_j + (q_i^T + v^T)W_{k,R}R_{i-j} Ai,jrel?=(qiT?+uT)kj?+(qiT?+vT)Wk,R?Ri?j?

第一项中不涉及相对位置信息,直接进行矩阵计算就可以,第二项中由于包含 R i ? j R_{i-j} Ri?j?,因此需要进行相对位置转换。

对于 q i T W k , R R i ? j q_i^T W_{k,R}R_{i-j} qiT?Wk,R?Ri?j?项,假设当前输入段的长度为 L L L,缓存的Memory长度为 M M M,则该项的shape为 L ? ( M + L ) L * (M+L) L?(M+L) R i ? j R_{i-j} Ri?j?取值范围为 [ 0 , L + M ? 1 ] [0, L+M-1] [0,L+M?1]。 当 R i ? j R_{i-j} Ri?j?逆序取值时:

Q k T = W k , R R M + L ? 1 ? k Q_k^T = W_{k, R}R_{M+L-1-k} QkT?=Wk,R?RM+L?1?k? Q Q Q矩阵中的每一行表示一个相对位置:

Q = [ R M + L ? 1 T R M + L ? 2 T . . . R 0 T ] W k , R T Q =\begin{bmatrix} R_{M+L-1}^T\\ R_{M+L-2}^T\\ ...\\ R_{0}^T \end{bmatrix} W_{k,R}^T Q=?????RM+L?1T?RM+L?2T?...R0T???????Wk,RT?
转换后

Q T = [ W k , R R M + L ? 1 W k , R R M + L ? 2 . . . W k , R R 0 ] Q^T = \begin{bmatrix}W_{k,R}R_{M+L-1}\\ W_{k,R}R_{M+L-2}\\ ...\\ W_{k,R}R_{0} \end{bmatrix} QT=?????Wk,R?RM+L?1?Wk,R?RM+L?2?...Wk,R?R0???????

B = q Q T = [ q 0 Q 0 ? q 0 Q M q 0 Q M + 1 ? q 0 Q L + M ? 1 q 1 Q 0 ? q 1 Q M q 1 Q M + 1 ? q 1 Q L + M ? 1 ? ? ? ? ? ? q L ? 1 Q 0 ? q L ? 1 Q M q L ? 1 Q M + 1 ? q L ? 1 Q M + L ? 1 ] B=qQ^T =\begin{bmatrix} q_0Q_0 & \cdots & q_0Q_M & q_0Q_{M+1} & \cdots & q_0Q_{L+M-1}\\ q_1Q_0 & \cdots & q_1Q_M & q_1Q_{M+1} & \cdots &q_1Q_{L+M-1}\\ \vdots&\ddots&\vdots &\vdots&\ddots&\vdots\\ q_{L-1}Q_0 & \cdots & q_{L-1}Q_M & q_{L-1}Q_{M+1} & \cdots & q_{L-1}Q_{M+L-1} \end{bmatrix} B=qQT=??????q0?Q0?q1?Q0??qL?1?Q0???????q0?QM?q1?QM??qL?1?QM??q0?QM+1?q1?QM+1??qL?1?QM+1???????q0?QL+M?1?q1?QL+M?1??qL?1?QM+L?1????????

显然B不是最终要求的(b)项,但存在很强的关联。对于第一行 q 0 q_0 q0?,即 L L L段的第一个元素,它与 L + M L+M L+M中所有元素的相对位置关系为:“ M M M、…、 1 1 1 0 0 0 L ? 1 L-1 L?1”。 同理第二行的相对位置关系为“ M + 1 M+1 M+1、…、 2 2 2 1 1 1 L ? 2 L-2 L?2”, 最后一个元素 q L ? 1 q_{L-1} qL?1?的相对位置关系为“ M + L ? 1 M+L-1 M+L?1、…、 L L L L ? 1 L-1 L?1 0 0 0”。即要求的(b)项为:
B s h i f t = [ q 0 Q L ? 1 ? q 0 Q M + L ? 2 q 0 Q M + L ? 1 ? 0 q 1 Q L ? 2 ? q 0 Q M + L ? 3 q 0 Q M + L ? 2 q 1 Q M + L ? 1 ? ? ? ? ? ? ? q L ? 1 Q 0 ? q L ? 1 Q M q L ? 1 Q M + 1 ? q L ? 1 Q M + L ? 1 ] B_{shift} = \begin{bmatrix}q_0Q_{L-1} & \cdots &q_0Q_{M+L-2} & q_0Q_{M+L-1} & \cdots & 0 \\ q_1Q_{L-2} & \cdots &q_0Q_{M+L-3} & q_0Q_{M+L-2} & q_1Q_{M+L-1} & \cdots \\ \vdots & \ddots & \vdots & \vdots & \vdots & \ddots \\ q_{L-1}Q_{0} & \cdots & q_{L-1}Q_M & q_{L-1}Q_{M+1} & \cdots & q_{L-1}Q_{M+L-1} \end{bmatrix} Bshift?=??????q0?QL?1?q1?QL?2??qL?1?Q0???????q0?QM+L?2?q0?QM+L?3??qL?1?QM??q0?QM+L?1?q0?QM+L?2??qL?1?QM+1???q1?QM+L?1????0??qL?1?QM+L?1????????

注意当前单词只能看到之前出现的信息,不能看到之后的单词,因此有些位置最终的注意力得分为0。那么如何得到 B s h i f t B_{shift} Bshift?,只需对 B B B进行如下转换:

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

这个技巧性很强,需要结合参考资料知乎上的图例加以理解。而对于D项,在实践上先和查询矩阵求和,在进行矩阵乘法、位置转换。 BD项计算完成之后,与AC项相加,得到最终的注意力得分。然后进行后面的加权求和等操作就可以预测下一个位置了。

?

参考资料

  1. 论文 《Transformer-XL: Attentive Language ModelsBeyond a Fixed-Length Context》
  2. Transformer-XL解读(论文 + PyTorch源码)https://blog.csdn.net/qq_22795223/article/details/106130388
  3. 【核心代码解读】Transformer-XL https://zhuanlan.zhihu.com/p/74485142
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-02 11:21:53  更:2021-09-02 11:22:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 20:36:09-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码