Any interpretation of the Transformer starts from its architecture diagram:

Unlike earlier RNN-based seq2seq models, the Transformer abandons the recurrent structure entirely:
- Encoder layer: {multi-head self-attention + feed-forward network} $\times n$
- Decoder layer: {masked multi-head self-attention + encoder-decoder multi-head attention + feed-forward network} $\times n$
Below we walk through how the matrix dimensions change inside the Transformer.

## Matrix Dimension Analysis

For one batch of data, the encoder input has shape (batch_size, sr_len) and the decoder input has shape (batch_size, tar_len). For simplicity, assume both the encoder and the decoder have a single layer. The dimension changes during training are as follows.

### Training

#### Encoder
| input_size | Layer | output_size | Layer_parameter_size | Note |
|---|---|---|---|---|
| batch_size × sr_len | Input Embedding | batch_size × sr_len × embed_size | sr_vocab_size × embed_size | Embedding weights can be either learnable or fixed |
| batch_size × sr_len × embed_size | Position Embedding | batch_size × sr_len × embed_size | 1 × sr_len × embed_size | Fixed |
| batch_size × sr_len × embed_size | MultiHead Attention | batch_size × sr_len × hidden_size | {embed_size × hidden_size} × 3 + {hidden_size × hidden_size} | Learnable |
| batch_size × sr_len × hidden_size | AddNorm1 | batch_size × sr_len × hidden_size | None | |
| batch_size × sr_len × hidden_size | Feed Forward | batch_size × sr_len × hidden_size | {hidden_size × filter_size} + {filter_size × hidden_size} | Learnable |
| batch_size × sr_len × hidden_size | AddNorm2 | batch_size × sr_len × hidden_size | None | |
#### Decoder
| input_size | Layer | output_size | Layer_parameter_size | Note |
|---|---|---|---|---|
| batch_size × tar_len | Output Embedding | batch_size × tar_len × embed_size | tar_vocab_size × embed_size | Embedding weights can be either learnable or fixed |
| batch_size × tar_len × embed_size | Position Embedding | batch_size × tar_len × embed_size | 1 × tar_len × embed_size | Fixed |
| batch_size × tar_len × embed_size | Masked MultiHead Attention | batch_size × tar_len × hidden_size | {embed_size × hidden_size} × 3 + {hidden_size × hidden_size} | Learnable |
| batch_size × tar_len × hidden_size | AddNorm1 | batch_size × tar_len × hidden_size | None | |
| batch_size × tar_len × hidden_size | Encoder-Decoder MultiHead Attention | batch_size × tar_len × hidden_size | {hidden_size × hidden_size} × 4 | Learnable |
| batch_size × tar_len × hidden_size | AddNorm2 | batch_size × tar_len × hidden_size | None | |
| batch_size × tar_len × hidden_size | Feed Forward | batch_size × tar_len × hidden_size | {hidden_size × filter_size} + {filter_size × hidden_size} | Learnable |
| batch_size × tar_len × hidden_size | AddNorm3 | batch_size × tar_len × hidden_size | None | |
Note that for the encoder and decoder layers to be stackable, every layer's input and output dimensions must match, which requires embed_size = hidden_size.
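The Position Embedding rows above list a fixed parameter block of shape 1 × len × embed_size. A common choice, and the one in the original paper, is the sinusoidal encoding; here is a sketch (max_len is an assumed cap, and embed_size is assumed even):

```python
# Fixed sinusoidal position embedding, shape (1, max_len, embed_size).
import torch

def sinusoidal_position_embedding(max_len: int, embed_size: int) -> torch.Tensor:
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, embed_size, 2, dtype=torch.float32)
        * (-torch.log(torch.tensor(10000.0)) / embed_size)
    )
    pe = torch.zeros(max_len, embed_size)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
    return pe.unsqueeze(0)                        # (1, max_len, embed_size)
```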
### Inference
The inference-time encoder is identical to the training-time one, except that batch_size = 1. For the decoder, each step can only see information before the current position, so the tar_len of each input is also 1.
#### Decoder
| input_size | Layer | output_size |
|---|---|---|
| 1 × 1 | Output Embedding | 1 × 1 × embed_size |
| 1 × 1 × embed_size | Position Embedding | 1 × 1 × embed_size |
| 1 × 1 × embed_size | Masked MultiHead Attention | 1 × 1 × hidden_size |
| 1 × 1 × hidden_size | AddNorm1 | 1 × 1 × hidden_size |
| 1 × 1 × hidden_size | Encoder-Decoder MultiHead Attention | 1 × 1 × hidden_size |
| 1 × 1 × hidden_size | AddNorm2 | 1 × 1 × hidden_size |
| 1 × 1 × hidden_size | Feed Forward | 1 × 1 × hidden_size |
| 1 × 1 × hidden_size | AddNorm3 | 1 × 1 × hidden_size |
## Multihead Attention in Detail
### Training
#### Encoder Multihead Attention
- Input: the Encoder Multihead Attention takes identical query, key, and value: the source sentence after word embedding and position embedding, with shape $\text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size}$. Since num_heads heads are computed in parallel, query, key, and value each first go through a linear transform and are then split across the num_heads heads for the attention lookup, i.e.:
$$
\begin{aligned}
\boldsymbol{query}&: \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{key}&: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{value}&: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
Since query, key, and value are identical, we have sr_len_q = sr_len_k = sr_len_v.
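For concreteness, a minimal sketch of the linear-transform-then-reshape step for the query (key and value are handled identically); all sizes are illustrative assumptions:

```python
# Project, then split the last dimension across heads.
import torch
import torch.nn as nn

batch_size, sr_len, hidden_size, num_heads = 2, 6, 512, 8
head_dim = hidden_size // num_heads

x = torch.randn(batch_size, sr_len, hidden_size)     # query = key = value = x
w_q = nn.Linear(hidden_size, hidden_size)            # the linear transform

q = w_q(x)                                           # (batch_size, sr_len, hidden_size)
q = q.view(batch_size, sr_len, num_heads, head_dim)  # split hidden_size into heads
q = q.transpose(1, 2)                                # (batch_size, num_heads, sr_len, head_dim)
```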
- DotProductAttention: the num_heads heads are computed in parallel, i.e.:
$$
\begin{aligned}
\boldsymbol{query}&: \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{key}&: \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{value}&: \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
&\Downarrow\\
\boldsymbol{query} * \boldsymbol{key}^T &= \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\
&\stackrel{\text{mask out padding positions in } \boldsymbol{key}}{\Downarrow}\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) &= \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\
&\Downarrow\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) * \boldsymbol{value} &= \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
In the Encoder Multihead Attention, the key is masked before the softmax to remove the effect of padding. Padding actually affects the query as well as the key, but in the actual code the mask is applied only to the key, not the query. The original code did contain both a key mask and a query mask; the author later removed the query mask, because masking out the loss at padding positions achieves the same effect.
Suppose batch_size = num_heads = 1, sr_len_q = sr_len_k = 6, and the last two positions of the source sentence are padding. Then the mask in the Encoder Multihead Attention is:
$$
\begin{pmatrix}
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0
\end{pmatrix}
$$

That is, only the padding positions of the key are masked.
- Output: the results of the num_heads heads are stacked back together and passed through one more linear transform:
$$
\begin{aligned}
\text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
&\stackrel{\text{reshape}}{\Downarrow}\\
\text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size}\\
&\stackrel{\text{linear}}{\Downarrow}\\
\text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size}
\end{aligned}
$$
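A corresponding sketch of the merge step: transpose the heads back, flatten them into hidden_size, then apply the output linear transform (sizes assumed):

```python
# Merge num_heads heads back into hidden_size, then the output projection.
import torch
import torch.nn as nn

batch_size, num_heads, sr_len_q, head_dim = 2, 8, 6, 64
hidden_size = num_heads * head_dim

heads_out = torch.randn(batch_size, num_heads, sr_len_q, head_dim)
merged = heads_out.transpose(1, 2).reshape(batch_size, sr_len_q, hidden_size)
w_o = nn.Linear(hidden_size, hidden_size)
out = w_o(merged)   # (batch_size, sr_len_q, hidden_size)
```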
#### Masked Multihead Attention
Like the Encoder Multihead Attention, the Masked Multihead Attention takes identical query, key, and value: the target sentence after word embedding and position embedding. The subsequent computation is also essentially the same.
The main difference: at inference time, each step can only see the steps before it, not those after it. The mask in the Masked Multihead Attention therefore has to remove not only the effect of padding in the key, but also all information from the steps after the current one:
Suppose batch_size = num_heads = 1, tar_len_q = tar_len_k = 5, and the last two positions of the target sentence are padding. Then the mask in the Masked Multihead Attention is:
$$
\begin{pmatrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0
\end{pmatrix}
$$

Note that this mask is not simply a lower-triangular matrix: the last two positions are padding, so they must be masked out in any case.
#### Encoder-Decoder Multihead Attention
- Input: the query of the Encoder-Decoder Multihead Attention comes from the target sentence and has shape $\text{batch\_size} \times \text{tar\_len} \times \text{hidden\_size}$, while the key and value come from the output of the encoder layer and have shape $\text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size}$. As before, a linear transform is applied first, then the result is split into num_heads heads:
$$
\begin{aligned}
\boldsymbol{query}&: \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{key}&: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{value}&: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
Here tar_len_q $\neq$ sr_len_k = sr_len_v in general.
- DotProductAttention: the num_heads heads can still be computed in parallel:
$$
\begin{aligned}
\boldsymbol{query} * \boldsymbol{key}^T &= \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\
&\stackrel{\text{mask out padding positions in } \boldsymbol{key}}{\Downarrow}\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) &= \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\
&\Downarrow\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) * \boldsymbol{value} &= \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
Suppose batch_size = num_heads = 1. Here tar_len_q need not equal sr_len_k, so say tar_len_q = 5 and sr_len_k = 6. Since the mask targets only the key, we only care about padding in the source sentence; assuming the last two positions of the source sentence are padding, the mask in the Encoder-Decoder Multihead Attention is:
$$
\begin{pmatrix}
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0 & 0
\end{pmatrix}
$$
- Output: the results of the num_heads heads are stacked back together and passed through one more linear transform:
$$
\begin{aligned}
\text{batch\_size} \times \text{num\_heads} &\times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
&\stackrel{\text{reshape}}{\Downarrow}\\
\text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size}\\
&\stackrel{\text{linear}}{\Downarrow}\\
\text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size}
\end{aligned}
$$
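As a sanity check on these shapes, here is a sketch of the cross-attention case using PyTorch's built-in nn.MultiheadAttention, where the query length differs from the key length (in this API, key_padding_mask is True at padding positions; all sizes are assumptions):

```python
# Cross-attention with tar_len_q != sr_len_k and a source padding mask.
import torch
import torch.nn as nn

hidden_size, num_heads = 512, 8
attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

query = torch.randn(1, 5, hidden_size)     # decoder side, tar_len_q = 5
memory = torch.randn(1, 6, hidden_size)    # encoder output, sr_len_k = 6
key_padding_mask = torch.tensor([[False, False, False, False, True, True]])

out, _ = attn(query, memory, memory, key_padding_mask=key_padding_mask)
assert out.shape == (1, 5, hidden_size)
```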
### Inference
#### Encoder Multihead Attention
This is exactly the same as the Encoder Multihead Attention during training.
#### Masked Multihead Attention
Although the Masked Multihead Attention masks out all steps after the current one during training, the whole target sentence is known at training time, so the computation can still be done in parallel.
At inference, however, the initial query, key, and value are all just the start symbol "<bos>". Each newly predicted token directly becomes the query of the next step, and is appended to the existing key and value to form the next step's key and value. In other words, at each step the query is the token produced at the previous step, while the key and value contain the tokens produced at all previous steps.
As for the mask: the input no longer contains any information from future steps, so no mask is needed to remove it. And since the key of the Masked Multihead Attention is the target sentence, whose length is unknown until decoding finishes, no key padding mask is needed either. In short, the Masked Multihead Attention needs no mask at inference.
The inference-time flow of the Masked Multihead Attention is as follows:
- Input: the key and value contain the information of all steps up to the current one, with shape 1 $\times$ cur_tar_len $\times$ hidden_size; the query is the token produced at the previous step, with shape 1 $\times$ 1 $\times$ hidden_size:
$$
\begin{aligned}
\boldsymbol{query}&: 1 \times 1 \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times 1 \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{key}&: 1 \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times \text{cur\_tar\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{value}&: 1 \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times \text{cur\_tar\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
- DotProductAttention: the num_heads heads are still computed in parallel:
$$
\begin{aligned}
\boldsymbol{query} * \boldsymbol{key}^T &= 1 \times \text{num\_heads} \times 1 \times \text{cur\_tar\_len\_k}\\
&\Downarrow\\
\text{softmax}(\boldsymbol{query} * \boldsymbol{key}^T) &= 1 \times \text{num\_heads} \times 1 \times \text{cur\_tar\_len\_k}\\
&\Downarrow\\
\text{softmax}(\boldsymbol{query} * \boldsymbol{key}^T) * \boldsymbol{value} &= 1 \times \text{num\_heads} \times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
- Output: the results of the num_heads heads are stacked back together and passed through one more linear transform:
$$
\begin{aligned}
1 \times \text{num\_heads} &\times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
&\stackrel{\text{reshape}}{\Downarrow}\\
1 \times 1 &\times \text{hidden\_size}\\
&\stackrel{\text{linear}}{\Downarrow}\\
1 \times 1 &\times \text{hidden\_size}
\end{aligned}
$$
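A sketch of this incremental scheme with a growing key/value cache; the sizes are assumptions, and the per-step projections are faked with random tensors to keep the example self-contained:

```python
# Incremental decoding: append each step's key/value to a cache and let the
# single-step query attend over everything cached so far (no mask needed).
import torch

num_heads, head_dim = 8, 64

k_cache = torch.zeros(1, num_heads, 0, head_dim)  # cur_tar_len starts at 0
v_cache = torch.zeros(1, num_heads, 0, head_dim)

for step in range(3):                             # 3 decoding steps for illustration
    q = torch.randn(1, num_heads, 1, head_dim)    # current step only
    k_new = torch.randn(1, num_heads, 1, head_dim)
    v_new = torch.randn(1, num_heads, 1, head_dim)
    k_cache = torch.cat([k_cache, k_new], dim=2)  # cur_tar_len grows by one
    v_cache = torch.cat([v_cache, v_new], dim=2)
    scores = q @ k_cache.transpose(-2, -1) / head_dim ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_cache  # (1, num_heads, 1, head_dim)
```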
#### Encoder-Decoder Multihead Attention
At inference, the query of the Encoder-Decoder Multihead Attention is the output of the Masked Multihead Attention below it, with shape 1 $\times$ 1 $\times$ hidden_size; the key and value are the output of the encoder layer, with shape 1 $\times$ sr_len $\times$ hidden_size. The flow is:
- Input:
$$
\begin{aligned}
\boldsymbol{query}&: 1 \times 1 \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times 1 \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{key}&: 1 \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
\boldsymbol{value}&: 1 \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{linear}}{\Longrightarrow} 1 \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow} 1 \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
- DotProductAttention: the num_heads heads are computed in parallel:
$$
\begin{aligned}
\boldsymbol{query} * \boldsymbol{key}^T &= 1 \times \text{num\_heads} \times 1 \times \text{sr\_len\_k}\\
&\stackrel{\text{mask out padding positions in } \boldsymbol{key}}{\Downarrow}\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) &= 1 \times \text{num\_heads} \times 1 \times \text{sr\_len\_k}\\
&\Downarrow\\
\text{masked\_softmax}(\boldsymbol{query} * \boldsymbol{key}^T) * \boldsymbol{value} &= 1 \times \text{num\_heads} \times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}
\end{aligned}
$$
Assume num_heads = 1 and sr_len_k = 6. Since the mask targets only the key, we only care about padding in the source sentence; assuming the last two positions of the source sentence are padding, the mask in the Encoder-Decoder Multihead Attention is:
$$
\begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0 \end{pmatrix}
$$
- Output: the results of the num_heads heads are stacked back together and passed through one more linear transform:
$$
\begin{aligned}
1 \times \text{num\_heads} &\times 1 \times \frac{\text{hidden\_size}}{\text{num\_heads}}\\
&\stackrel{\text{reshape}}{\Downarrow}\\
1 \times 1 &\times \text{hidden\_size}\\
&\stackrel{\text{linear}}{\Downarrow}\\
1 \times 1 &\times \text{hidden\_size}
\end{aligned}
$$
Since the final Feed Forward layer does not change the matrix size, we can now summarize the inference-time decoder layer: the input is the token produced at the previous step, with shape 1 $\times$ 1 $\times$ hidden_size; after the two kinds of MultiHead Attention plus Feed Forward, the shape is still 1 $\times$ 1 $\times$ hidden_size; a final Linear + Softmax then yields the predicted token for the current step, which in turn becomes the input query of the next step. Decoding stops when the maximum length is reached or the output token is "<eos>".