对Transformer中Positional Encoding的理解
Transformer是最新的处理序列到序列问题的架构,由self-attention组成,其优良的可并行性以及可观的表现提升,让它在NLP领域中大受欢迎,GPT-3以及BERT都是基于Transformer实现的。
刚开始学Transformer,对一个模块Positional Encoding存在一些疑问,因此,参考了一些资料和博客,学习如何理解Positional Encoding。
1. 什么是Positional Encoding?为什么Transformer需要使用Positional Encoding?
在任何一门语言中,词语的位置和顺序对句子意思表达都是至关重要的。传统的RNN模型在处理句子时,以序列的模式逐个处理句子中的词语,这使得词语的顺序信息在处理过程中被天然的保存下来了,并不需要额外的处理。
而对于Transformer来说,由于句子中的词语都是同时进入网络进行处理,顺序信息在输入网络时就已丢失。因此,Transformer是需要额外的处理来告知每个词语的相对位置的。其中的一个解决方案,就是论文中提到的Positional Encoding,将能表示位置信息的编码添加到输入中,让网络知道每个词的位置和顺序。
一句话概括,Positional Encoding就是句子中词语相对位置的编码,让Transformer保留词语的位置信息。
Transformer的输入
首先给出Transformer的输入部分,如上图所示。X:[batch size,sequence length] 指的是初始输入的多语句矩阵,多语句矩阵通过查表 ,得到词向量矩阵
X
e
m
b
e
d
d
i
n
g
X_{embedding}
Xembedding?:[batch size,sequence length,embedding dimension] 。batch size指的是句子数,sequence length指的是输入的句子中最长的句子的字数,embedding dimension指的是词向量的长度(通过查表得到)。
X
X
X和
X
e
m
b
e
d
d
i
n
g
X_{embedding}
Xembedding?的示意图如下图所示:
2. Positional Encoding是怎么做的?
要表示位置信息,首先出现在脑海里的一个点子可能是,给句子中的每个词赋予一个相位,也就是[0, 1]中间的一个值,第一个词是0,最后一个词是1,中间的词在0到1之间取值。
但是这样会不会有什么问题呢?其中一个问题在于,你并不知道每个句子中词语的个数是多少,这会导致每个词语之间的间隔变化是不一致的。而对于一个句子来说,每个词语之间的间隔都应该是具有相同含义的。
那,为了保证每个词语的间隔含义一致,我们是不是可以给每个词语添加一个线性增长的时间戳呢?比如说第一个词是0,第二词是1,以此类推,第N个词的位置编码是N。
这样其实也会有问题。同样,我们并不知道一个句子的长度,如果训练的句子很长的话,这样的编码是不合适的。 另外,这样训练出来的模型,在泛化性上是有一定问题的。
因此,理想情况下,编码方式应该要满足以下几个条件:
- 对于每个位置的词语,它都能提供一个独一无二的编码
- 词语之间的间隔对于不同长度的句子来说,含义应该是一致的
- 能够随意延申到任意长度的句子
文中提出了一种简单且有效的编码方式,能够满足上述所有条件。
公式表达
其中,PE为二维矩阵,大小跟输入embedding的维度一样,行表示词语,列表示词向量;
p
o
s
pos
pos表示词语在句子中的位置;
d
m
o
d
e
l
d_{model}
dmodel?表示词向量的维度;
i
i
i表示词向量的位置。因此,上述公式表示在每个词语的词向量的偶数位置添加sin变量,奇数位置添加cos变量,以此来填满整个PE矩阵,然后加到input embedding中去,这样便完成了位置编码的引入,
使用sin编码和cos编码的原因是可以得到词语之间的相对位置,因为:
sin
?
(
α
+
β
)
=
sin
?
α
cos
?
β
+
cos
?
α
sin
?
β
\sin{(\alpha+\beta)} = \sin{\alpha}\cos{\beta}+\cos{\alpha}\sin{\beta}
sin(α+β)=sinαcosβ+cosαsinβ
cos
?
(
α
+
β
)
=
cos
?
α
cos
?
β
?
sin
?
α
sin
?
β
\cos{(\alpha+\beta)} = \cos{\alpha}\cos{\beta} - \sin{\alpha}\sin{\beta}
cos(α+β)=cosαcosβ?sinαsinβ
即由
sin
?
(
p
o
s
+
k
)
\sin(pos+k)
sin(pos+k)可以得到,通过线性变换获取后续词语相对当前词语的位置关系。
源码展示
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0., d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
return self.dropout(x)
其中,div_term是上述公式经过简单的数学变换得到的,具体如下:
1
/
1000
0
2
i
/
d
m
o
d
e
l
=
e
l
o
g
(
1000
0
?
2
i
/
d
m
o
d
e
l
)
=
e
?
2
i
/
d
m
o
d
e
l
?
l
o
g
10000
=
e
2
i
?
(
?
l
o
g
10000
/
d
m
o
d
e
l
)
1/10000^{2i/d_{model}}=e^{log{(10000^{-2i/d_{model}})}}=e^{-2i/d_{model} * log{10000}}=e^{2i * (-log{10000}/d_{model})}
1/100002i/dmodel?=elog(10000?2i/dmodel?)=e?2i/dmodel??log10000=e2i?(?log10000/dmodel?)
直观理解
为什么这样简单的sines和cosines的组合可以表达位置信息呢?一开始的确有点难理解。举个二进制的例子就明白了。可以观察一下下面这个表,将数字用二进制表示出来。可以发现,每个比特位的变化率是不一样的,越低位的变化越快,红色位置0和1每个数字会变化一次,而黄色位,每8个数字才会变化一次。 不同频率的sines和cosines组合其实也是同样的道理,通过调整三角函数的频率,可以实现这种低位到高位的变化,这样的话,位置信息就表示出来了。
计算过程
如上图所示,word embedding指的是词向量由每个词根据查表得到,pos embedding就是我们要求的Positional Encoding,也就是位置编码。可以看到word embedding和pos embedding逐点相加得到composition,即包含语义信息和位置编码信息的最终矩阵。
回到公式中,我们可以得知:
p
o
s
pos
pos指当前字符在句子中的位置(如:“你好啊”,这句话里面“你”的
p
o
s
=
0
pos=0
pos=0),
d
m
o
d
e
l
d_{model}
dmodel?指的是word embedding的长度(比如说:查表得到“民主”这个词的word embedding为
[
1
,
2
,
3
,
4
,
5
]
[1,2,3,4,5]
[1,2,3,4,5],则
d
m
o
d
e
l
=
5
d_{model}=5
dmodel?=5),
i
i
i的取值范围是:
i
=
0
,
1
,
.
.
.
,
d
m
o
d
e
l
?
1
i=0,1,...,d_{model}-1
i=0,1,...,dmodel??1。当
i
i
i的值为偶数时使用上面那条公式,当
i
i
i的值为奇数时使用下面那条公式。当
p
o
s
=
3
,
d
m
o
d
e
l
=
128
pos=3, d_{model}=128
pos=3,dmodel?=128时Positional Encoding(或者说是pos embedding)的计算结果为: 每一个字所计算出来的Positional Encoding并不是一个值而是一个向量,它的长度和这个字的word embedding的长度一致,从而方便它们两个逐点相加得到既包含word embedding又包含位置信息的最终向量。
参考资料
1.https://zhuanlan.zhihu.com/p/338592312 2.https://blog.csdn.net/weixin_44012382/article/details/113059423 3.https://arxiv.org/pdf/1706.03762.pdf
|