Introduction
This article implements the attention scoring methods commonly used in Seq2Seq models with PyTorch.
Attention Methods
score(h_t, \overline{h}_s) = \begin{cases} h_t^T \overline{h}_s & \text{dot} \\ h_t^T W_a \overline{h}_s & \text{general} \\ v_a^T \tanh (W_a[h_t; \overline{h}_s]) & \text{concat} \\ v_a^T \tanh (W_a\overline{h}_s + U_a h_t) & \text{bahdanau} \end{cases}
Combining the papers Effective Approaches to Attention-based Neural Machine Translation and NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE, we arrive at the four ways of computing attention scores shown above.
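As a quick sanity check of the formulas, here is a minimal sketch that computes all four scores for a single pair of vectors h_t and \overline{h}_s (the vector/matrix shapes and the names W_a, U_a, v_a are assumptions chosen only to mirror the equation); the batched version used in practice follows in the code section.

import torch

hidden_size = 4
h_t = torch.rand(hidden_size)                      # current decoder hidden state (assumed random)
h_s = torch.rand(hidden_size)                      # one encoder hidden state (assumed random)
W_a = torch.rand(hidden_size, hidden_size)         # weight matrix for general/bahdanau
U_a = torch.rand(hidden_size, hidden_size)         # second weight matrix for bahdanau
W_cat = torch.rand(hidden_size, 2 * hidden_size)   # weight matrix applied to [h_t; h_s] for concat
v_a = torch.rand(hidden_size)                      # projection vector for concat/bahdanau

dot_score = h_t @ h_s                                            # h_t^T h_s
general_score = h_t @ (W_a @ h_s)                                # h_t^T W_a h_s
concat_score = v_a @ torch.tanh(W_cat @ torch.cat((h_t, h_s)))   # v_a^T tanh(W_a [h_t; h_s])
bahdanau_score = v_a @ torch.tanh(W_a @ h_s + U_a @ h_t)         # v_a^T tanh(W_a h_s + U_a h_t)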
The weight \alpha_{ij} corresponding to each encoder output h_j is computed by the following formula:
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} \tag{6}

where

e_{ij} = a(s_{i-1}, h_j)
See (paper translation) NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE.
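Equation (6) is simply a softmax over the energies of all encoder time steps for a given decoder step. A minimal sketch with made-up energy values (T_x = 5 is an assumption for illustration):

import torch

energies = torch.tensor([1.2, -0.3, 0.7, 2.1, 0.0])  # assumed e_{i1}, ..., e_{iT_x} for one decoder step i
alphas = torch.softmax(energies, dim=0)               # equation (6): exp(e_ij) / sum_k exp(e_ik)
print(alphas, alphas.sum())                           # non-negative weights that sum to 1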
Code Implementation
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, hidden_size, method='dot'):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        if self.method not in ['dot', 'general', 'concat', 'bahdanau']:
            raise ValueError(f"{self.method} is not an appropriate attention method.")
        if self.method == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            # initialize va explicitly; an empty FloatTensor would contain arbitrary values
            self.va = nn.Parameter(torch.rand(1, hidden_size))
        elif self.method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(1, hidden_size))

    def _score(self, last_hidden, encoder_outputs):
        '''
        :param last_hidden: output of the decoder's last layer (if it has several) [1, batch_size, hidden_size];
                            the decoder handles one time step at a time and is unidirectional, so D=1
        :param encoder_outputs: encoder hidden states for all time steps [seq_len, batch_size, hidden_size]
        :return: unnormalized attention energies [seq_len, batch_size]
        '''
        if self.method == 'dot':
            return torch.sum(last_hidden * encoder_outputs, dim=2)
        elif self.method == 'general':
            energy = self.Wa(last_hidden)
            return torch.sum(encoder_outputs * energy, dim=2)
        elif self.method == 'concat':
            energy = torch.tanh(
                self.Wa(torch.cat((encoder_outputs, last_hidden.expand(encoder_outputs.size(0), -1, -1)), dim=2)))
            return torch.sum(self.va * energy, dim=2)
        else:  # bahdanau
            energy = torch.tanh(self.Wa(last_hidden) + self.Ua(encoder_outputs))
            return torch.sum(self.va * energy, dim=2)

    def forward(self, last_hidden, encoder_outputs):
        # [seq_len, batch_size]
        attn_energies = self._score(last_hidden, encoder_outputs)
        # [batch_size, seq_len]
        attn_energies = attn_energies.t()
        # normalize over the sequence dimension and add a middle dimension for bmm: [batch_size, 1, seq_len]
        return torch.softmax(attn_energies, dim=1).unsqueeze(1)
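A minimal usage sketch of the class above (seq_len, batch_size and hidden_size are arbitrary example values): the returned weights have shape [batch_size, 1, seq_len], so batch-multiplying them with the encoder outputs transposed to batch-first yields the context vector.

seq_len, batch_size, hidden_size = 7, 2, 16              # arbitrary example sizes
encoder_outputs = torch.rand(seq_len, batch_size, hidden_size)
last_hidden = torch.rand(1, batch_size, hidden_size)

attn = Attention(hidden_size, method='general')
attn_weights = attn(last_hidden, encoder_outputs)         # [batch_size, 1, seq_len]
# weighted sum of encoder outputs -> context vector [batch_size, 1, hidden_size]
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))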