IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Transfomer矩阵维度分析及MultiHead详解 -> 正文阅读

[人工智能]Transfomer矩阵维度分析及MultiHead详解


解读Transformer就离不开下面这张图:

不同于之前的基于rnn的seq2seq模型,Transfomer完全摒弃了循环神经网络的结构:

  1. encoder层: {多头自注意力 + 前馈网络} × n \times n ×n
  2. decoder层: {Masked 多头自注意力 + encoder-decoder多头自注意力 + 前馈网络} × n \times n ×n

下面我们介绍Transformer模型中的矩阵维度变化情况:

矩阵维度分析

对于一个batch的数据,encoder端的输入大小为:(batch_size, sr_len);decoder端的输入大小为:(batch_size, tar_len)。不妨假设 encoder layer 及 decoder layer 都只有一层,下面是训练阶段的矩阵维度变化:

训练阶段

训练阶段 Encoder

input_sizeLayeroutput_sizeLayer_parameter_sizeNote
batch_size × \times × sr_lenInput Embeddingbatch_size × \times × sr_len × \times × embed_sizesr_vocab_size × \times × embed_sizeEmbedding层的参数即可设为可学习的,也可设为固定参数
batch_size × \times × sr_len × \times × embed_sizePostion Embeddingbatch_size × \times × sr_len × \times × embed_size1 × \times × sr_len × \times × embed_size固定参数
batch_size × \times × sr_len × \times × embed_sizeMultiHead Attentionbatch_size × \times × sr_len × \times × hidden_size{embed_size × \times × hidden_size} × \times × 3 + {hidden_size × \times × hidden_size}可学习参数
batch_size × \times × sr_len × \times × hidden_sizeAddNorm1batch_size × \times × sr_len × \times × hidden_sizeNone
batch_size × \times × sr_len × \times × hidden_sizeFeed Forwardbatch_size × \times × sr_len × \times × hidden_size{hidden_size × \times × filter_size} + {filter_size × \times × hidden_size}可学习参数
batch_size × \times × sr_len × \times × hidden_sizeAddNorm2batch_size × \times × sr_len × \times × hidden_sizeNone

训练阶段 Decoder

input_sizeLayeroutput_sizeLayer_parameter_sizeNote
batch_size × \times × tar_lenOutput Embeddingbatch_size × \times × tar_len × \times × embed_sizetar_vocab_size × \times × embed_sizeEmbedding层的参数即可设为可学习的,也可设为固定参数
batch_size × \times × tar_len × \times × embed_sizePostion Embeddingbatch_size × \times × tar_len × \times × embed_size1 × \times × tar_len × \times × embed_size固定参数
batch_size × \times × tar_len × \times × embed_sizeMasked MultiHead Attentionbatch_size × \times × tar_len × \times × hidden_size{embed_size × \times × hidden_size} × \times × 3 + {hidden_size × \times × hidden_size}可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm1batch_size × \times × tar_len × \times × hidden_sizeNone
batch_size × \times × tar_len × \times × hidden_sizeEncoder-Decoder MultiHead Attentionbatch_size × \times × tar_len × \times × hidden_size{hidden_size × \times × hidden_size} × \times × 4可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm2batch_size × \times × tar_len × \times × hidden_sizeNone
batch_size × \times × tar_len × \times × hidden_sizeFeed Forwardbatch_size × \times × tar_len × \times × hidden_size{hidden_size × \times × filter_size} + {filter_size × \times × hidden_size}可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm3batch_size × \times × tar_len × \times × hidden_sizeNone

注意到,为了保持encoder及decoder的层可以堆叠,需要保证每个层的输入和输出的维度一致,因此,需要保证 embed_size = hidden_size


预测阶段

预测阶段的 encoder 与训练阶段是相同的,只是 batch_size = 1;而 decoder 部分由于每个 step 只能看到当前位置之前的信息,因此每次输入的 tar_len 也等于 1。

预测阶段 Decoder

input_sizeLayeroutput_size
1 × \times × 1Output Embedding1 × \times × 1 × \times × embed_size
1 × \times × 1 × \times × embed_sizePostion Embedding1 × \times × 1 × \times × embed_size
1 $\times$1 × \times × embed_sizeMasked MultiHead Attention1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm11 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeEncoder-Decoder MultiHead Attention1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm21 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeFeed Forward1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm31 × \times × 1 × \times × hidden_size

Multihead Attention解析

训练阶段

Encoder Multihead Attention

在这里插入图片描述

  1. Input: Encoder Multihead Attention 输入的 query, key, value 是相同的,都是经过了word embedding和pos embedding之后的 source sentence,其维度为 batch_size × sr_len × hidden_size \text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size} batch_size×sr_len×hidden_size 。由于有 num_heads 个头需要并行计算,首先 query, key, value 分别经过一个线性变换,再将数据 split 给 num_heads 个头分别做注意力查询,即:
    q u e r y : batch_size × sr_len_q × hidden_size ? 线性变换 batch_size × sr_len_q × hidden_size ? reshape batch_size × num_heads × sr_len_q × hidden_size num_heads k e y : batch_size × sr_len_k × hidden_size ? 线性变换 batch_size × sr_len_k × hidden_size ? reshape batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × sr_len_v × hidden_size ? 线性变换 batch_size × sr_len_v × hidden_size ? reshape batch_size × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} query:batch_size×sr_len_q×hidden_size?线性变换?batch_size×sr_len_q×hidden_size?reshape?batch_size×num_heads×sr_len_q×num_headshidden_size?key:batch_size×sr_len_k×hidden_size?线性变换?batch_size×sr_len_k×hidden_size?reshape?batch_size×num_heads×sr_len_k×num_headshidden_size?value:batch_size×sr_len_v×hidden_size?线性变换?batch_size×sr_len_v×hidden_size?reshape?batch_size×num_heads×sr_len_v×num_headshidden_size??

由于query, key, value 是相同的,因此有 sr_len_q = sr_len_k = sr_len_v

  1. DotProductAttention: num_heads 个头的计算是并行的,即:
    q u e r y : batch_size × num_heads × sr_len_q × hidden_size num_heads k e y : batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × num_heads × sr_len_v × hidden_size num_heads ? q u e r y ? k e y T = batch_size × num_heads × sr_len_q × sr_len_k ? 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ? k e y T ) = batch_size × num_heads × sr_len_q × sr_len_k ? masked_softmax ( q u e r y ? k e y T ) ? v a l u e = batch_size × num_heads × sr_len_q × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \Downarrow\\ \boldsymbol {query} * \boldsymbol {key}^T = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} query:batch_size×num_headskey:batch_size×num_headsvalue:batch_size×num_heads?query?keyT=batch_size×num_heads?keypaddingmask?masked_softmax(query?keyT)=batch_size×num_heads?masked_softmax(query?keyT)?value=batch_size×num_heads?×sr_len_q×num_headshidden_size?×sr_len_k×num_headshidden_size?×sr_len_v×num_headshidden_size?×sr_len_q×sr_len_k×sr_len_q×sr_len_k×sr_len_q×num_headshidden_size??

Encoder Multihead Attention 中在计算 softmax 之前对 key 进行了 mask,目的是消除 padding 的影响。事实上 padding 不仅对 key 有影响,对 query 也有影响,但在实际代码中 mask 仅针对 key,而没有针对 query。其实最原始代码是既有 key mask,也有query mask的,但后来作者将 query mask 删去了,因为在最后计算 loss 的时候对 padding 位置的 loss 进行mask,也可达到相同的效果。

假设 batch_size = num_heads = 1,sr_len_q = sr_len_k = 6,source sentence 的最后两个位置是padding,那么Encoder Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0\\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \end{pmatrix} ?????????111111?111111?111111?111111?000000?000000??????????
即只对 key 的 padding 位置进行了 mask

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    batch_size × num_heads × sr_len_q × hidden_size num_heads ? reshape batch_size × sr_len_q × hidden_size ? 线性变换 batch_size × sr_len_q × hidden_size \begin{aligned} \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \stackrel{\text{reshape}}{\Downarrow}\\ \text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size}\\ \stackrel{\text{线性变换}}{\Downarrow}\\ \text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size} \end{aligned} batch_size×num_heads?reshape?batch_size×sr_len_q?线性变换?batch_size×sr_len_q?×sr_len_q×num_headshidden_size?×hidden_size×hidden_size?

Masked Multihead Attention

与 Encoder Multihead Attention 类似,Masked Multihead Attention 输入的 query, key, value 也是相同的,都是经过了word embedding和pos embedding之后的 target sentence。包括后面的计算流程也基本一致。

主要的区别在于:由于在 inference 时,每个 step 位置只能看到它之前的 steps 的信息,而看不到它之后的 steps的信息。因此 Masked Multihead Attention 中的 mask 除了要消除 key 信息里 padding 的影响,还需要消除当前 step 后面的所有 step 的信息:

假设 batch_size = num_heads = 1,tar_len_q = tar_len_k = 5,target sentence 的最后两个位置是 padding,那么Masked Multihead Attention 中的 mask 为:
( 1 0 0 0 0 1 1 0 0 0 1 1 1 0 0 1 1 1 0 0 1 1 1 0 0 ) \begin{pmatrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \end{pmatrix} ???????11111?01111?00111?00000?00000????????
注意到上述 mask 并不是一个单纯的下三角矩阵,因为最后两个位置都是padding,因此无论如何都要被 mask 掉


Encoder-Decoder Multihead Attention

  1. Input: Encoder-Decoder Multihead Attention 输入的 query 来自于 target sentence,其维度为 batch_size × tar_len × hidden_size \text{batch\_size} \times \text{tar\_len} \times \text{hidden\_size} batch_size×tar_len×hidden_size ;而 key 和 value 则来自于 encoder layer 的输出,其维度为 batch_size × sr_len × hidden_size \text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size} batch_size×sr_len×hidden_size 。同样是先做线性变换,再 split 成 num_heads 个头:
    q u e r y : batch_size × tar_len_q × hidden_size ? 线性变换 batch_size × tar_len_q × hidden_size ? reshape batch_size × num_heads × tar_len_q × hidden_size num_heads k e y : batch_size × sr_len_k × hidden_size ? 线性变换 batch_size × sr_len_k × hidden_size ? reshape batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × sr_len_v × hidden_size ? 线性变换 batch_size × sr_len_v × hidden_size ? reshape batch_size × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} query:batch_size×tar_len_q×hidden_size?线性变换?batch_size×tar_len_q×hidden_size?reshape?batch_size×num_heads×tar_len_q×num_headshidden_size?key:batch_size×sr_len_k×hidden_size?线性变换?batch_size×sr_len_k×hidden_size?reshape?batch_size×num_heads×sr_len_k×num_headshidden_size?value:batch_size×sr_len_v×hidden_size?线性变换?batch_size×sr_len_v×hidden_size?reshape?batch_size×num_heads×sr_len_v×num_headshidden_size??

这里 sr_len_q ≠ \neq ?= sr_len_k = sr_len_v

  1. DotProductAttention: num_heads 个头的计算依然可以并行:
    q u e r y ? k e y T = batch_size × num_heads × tar_len_q × sr_len_k ? 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ? k e y T ) = batch_size × num_heads × tar_len_q × sr_len_k ? masked_softmax ( q u e r y ? k e y T ) ? v a l u e = batch_size × num_heads × tar_len_q × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} query?keyT=batch_size×num_heads?keypaddingmask?masked_softmax(query?keyT)=batch_size×num_heads?masked_softmax(query?keyT)?value=batch_size×num_heads?×tar_len_q×sr_len_k×tar_len_q×sr_len_k×tar_len_q×num_headshidden_size??

假设 batch_size = num_heads = 1,这里sr_len_q可以不等于sr_len_k,不妨假设 sr_len_q = 5,sr_len_k = 6因为 mask 只针对key,因此这里只需要关注 source sentence 中的padding, 假设 source sentence 的最后两个位置是padding,那么Masked Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \end{pmatrix} ???????11111?11111?11111?11111?00000?00000????????

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    batch_size × num_heads × tar_len_q × hidden_size num_heads ? reshape batch_size × tar_len_q × hidden_size ? 线性变换 batch_size × tar_len_q × hidden_size \begin{aligned} \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \stackrel{\text{reshape}}{\Downarrow}\\ \text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size}\\ \stackrel{\text{线性变换}}{\Downarrow}\\ \text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size} \end{aligned} batch_size×num_heads?reshape?batch_size×tar_len_q?线性变换?batch_size×tar_len_q?×tar_len_q×num_headshidden_size?×hidden_size×hidden_size?

预测阶段

Encoder Multihead Attention

与训练阶段的 Encoder Multihead Attention 完全相同

Masked Multihead Attention

虽然在训练阶段,Masked Multihead Attention 会将当前 step 之后的 steps 信息都 mask 掉,但是由于训练时整个 target sentence 都是已知的,因此还是可以做并行运算的。

但是在预测阶段,初始的 query, key, value 都只是一个 “<bos>” 起始符号,之后每预测出一个 token,这个 token 直接作为下一个 step 输入的 query,而将这个 token 拼在现有的 key 和 value 之后,就是下一个 step 输入的 key 和 value。也就是说,预测阶段每个 step 输入的 query 是上一 step 输出的token,而 key, value 是之前所有 steps 输出的token

至于 mask 的部分,由于输入中不再含有未来 steps 的信息,因此不再需要用 mask 来消除这部分信息。而对于 key mask,由于 Masked Multihead Attention 的 key 是 target sentence,而在预测完成前 target sentence 的长度是未知的,因此针对 key 的 mask 也是不需要的也就是说,Masked Multihead Attention 是不需要 mask 的

下面是预测阶段 Masked Multihead Attention 的流程:

  1. Input: key, value 是到当前 step 为止的所有 steps 的信息,大小为 1 × \times × cur_tar_len × \times × hidden_size;而 query 是上一 step 的输出 token,大小为 1 × \times × 1 × \times × hidden_size:
    q u e r y : 1 × 1 × hidden_size ? 线性变换 1 × 1 × hidden_size ? reshape 1 × num_heads × 1 × hidden_size num_heads k e y : 1 × cur_tar_len_k × hidden_size ? 线性变换 1 × cur_tar_len_k × hidden_size ? reshape 1 × num_heads × cur_tar_len_k × hidden_size num_heads v a l u e : 1 × cur_tar_len_v × hidden_size ? 线性变换 1 × cur_tar_len_v × hidden_size ? reshape 1 × num_heads × cur_tar_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}&: \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}&: \text{1} \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{cur\_tar\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}&: \text{1} \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{cur\_tar\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} querykeyvalue?:1×1×hidden_size?线性变换?1×1×hidden_size?reshape?1×num_heads×1×num_headshidden_size?:1×cur_tar_len_k×hidden_size?线性变换?1×cur_tar_len_k×hidden_size?reshape?1×num_heads×cur_tar_len_k×num_headshidden_size?:1×cur_tar_len_v×hidden_size?线性变换?1×cur_tar_len_v×hidden_size?reshape?1×num_heads×cur_tar_len_v×num_headshidden_size??

  2. DotProductAttention: num_heads 个头的计算依然可以并行:
    q u e r y ? k e y T = 1 × num_heads × 1 × cur_tar_len_k ? softmax ( q u e r y ? k e y T ) = 1 × num_heads × 1 × cur_tar_len_k ? softmax ( q u e r y ? k e y T ) ? v a l u e = 1 × num_heads × 1 × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{cur\_tar\_len\_k}\\ \Downarrow\\ \text{softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{cur\_tar\_len\_k}\\ \Downarrow\\ \text{softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} query?keyT=1×num_heads?softmax(query?keyT)=1×num_heads?softmax(query?keyT)?value=1×num_heads?×1×cur_tar_len_k×1×cur_tar_len_k×1×num_headshidden_size??

  3. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    1 × num_heads × 1 × hidden_size num_heads ? reshape 1 × 1 × hidden_size ? 线性变换 1 × 1 × hidden_size \begin{aligned} \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ &\stackrel{\text{reshape}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size}\\ &\stackrel{\text{线性变换}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size} \end{aligned} 1×num_heads1×11×1?×1×num_headshidden_size??reshape?×hidden_size?线性变换?×hidden_size?


Encoder-Decoder Multihead Attention

预测阶段 Encoder-Decoder Multihead Attention 输入的 query 是上一层 Masked Multihead Attention 的输出,大小为 1 × \times × 1 × \times × hidden_size。而输入的 key 和 value 则是 encoder layer 的输出,大小为:1 × \times × sr_len × \times × hidden_size。具体流程为:

  1. Input:
    q u e r y : 1 × 1 × hidden_size ? 线性变换 1 × 1 × hidden_size ? reshape 1 × num_heads × 1 × hidden_size num_heads k e y : 1 × sr_len_k × hidden_size ? 线性变换 1 × sr_len_k × hidden_size ? reshape 1 × num_heads × sr_len_k × hidden_size num_heads v a l u e : 1 × sr_len_v × hidden_size ? 线性变换 1 × sr_len_v × hidden_size ? reshape 1 × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}&: \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}&: \text{1} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}&: \text{1} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} querykeyvalue?:1×1×hidden_size?线性变换?1×1×hidden_size?reshape?1×num_heads×1×num_headshidden_size?:1×sr_len_k×hidden_size?线性变换?1×sr_len_k×hidden_size?reshape?1×num_heads×sr_len_k×num_headshidden_size?:1×sr_len_v×hidden_size?线性变换?1×sr_len_v×hidden_size?reshape?1×num_heads×sr_len_v×num_headshidden_size??

  2. DotProductAttention: num_heads 个头的计算并行:
    q u e r y ? k e y T = 1 × num_heads × 1 × sr_len_k ? 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ? k e y T ) = 1 × num_heads × 1 × sr_len_k ? masked_softmax ( q u e r y ? k e y T ) ? v a l u e = 1 × num_heads × 1 × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} query?keyT=1×num_heads?keypaddingmask?masked_softmax(query?keyT)=1×num_heads?masked_softmax(query?keyT)?value=1×num_heads?×1×sr_len_k×1×sr_len_k×1×num_headshidden_size??

假设 num_heads = 1,sr_len_k = 6,因为 mask 只针对key,因此这里只需要关注 source sentence 中的padding, 假设 source sentence 的最后两个位置是padding,那么Masked Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0 \\ \end{pmatrix} (1?1?1?1?0?0?)

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    1 × num_heads × 1 × hidden_size num_heads ? reshape 1 × 1 × hidden_size ? 线性变换 1 × 1 × hidden_size \begin{aligned} \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ &\stackrel{\text{reshape}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size}\\ &\stackrel{\text{线性变换}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size} \end{aligned} 1×num_heads1×11×1?×1×num_headshidden_size??reshape?×hidden_size?线性变换?×hidden_size?

由于最后的 Feed Forward 层不改变矩阵大小,至此可以总结一下预测阶段的 Decoder layer,输入是上一 step 输出的 token,大小为 1 × \times × 1 × \times × hidden_size,经过两种MultiHead + Feed Forward 后,大小依然为 1 × \times × 1 × \times × hidden_size,再经过 Linear+Softmax,其输出就是预测的当前 step 的token,而这个 token 又会作为下一个 step 的输入 query。直到达到最大长度,或者输出的 token 是 “<eos>”

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-22 20:35:19  更:2022-03-22 20:37:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 13:47:00-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码