一. 基于注意力机制的seq2seq
1. Bahdanau 注意力介绍
在前面博客李沐动手学深度学习V2-seq2seq和代码实现中探讨了基于seq2seq架构的机器翻译问题: 通过设计一个基于两个循环神经网络的编码器-解码器架构, 用于序列到序列学习。 具体来说循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量, 然后循环神经网络解码器根据生成的词元和上下文变量按词元生成输出(目标)序列词元。 然而即使并非所有输入(源)词元都对解码某个词元都有用, 在每个解码步骤中仍使用编码相同的上下文变量,因此本节解决的是用什么方法能改变上下文变量呢?使用基于注意力机制的输出来改变上下文变量。 Bahdanau等人提出了一个没有严格单向对齐限制的可微注意力模型 ,在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分,这是通过将上下文变量视为注意力集中的输出来实现的。
2. 模型
下面描述的Bahdanau注意力模型与前面李沐动手学深度学习V2-seq2seq和代码实现seq2seq中的模型架构相同,只不过将前面seq2seq中在任何解码时间步
t
′
t'
t′的上下文变量
c
\mathbf{c}
c都会被基于注意力机制的
c
t
′
\mathbf{c}_{t'}
ct′?替换。 假设输入序列中有
T
T
T个词元,解码时间步
t
′
t'
t′的上下文变量是注意力集中的输出:
c
t
′
=
∑
t
=
1
T
α
(
s
t
′
?
1
,
h
t
)
h
t
,
\mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t,
ct′?=t=1∑T?α(st′?1?,ht?)ht?, 其中,时间步
t
′
?
1
t' - 1
t′?1时的解码器隐状态
s
t
′
?
1
\mathbf{s}_{t' - 1}
st′?1?是查询,编码器隐状态
h
t
\mathbf{h}_t
ht?既是键,也是值,注意力权重
α
\alpha
α是使用 前面博文李沐动手学深度学习V2-注意力评分函数所定义的加性注意力打分函数计算的。 与 前面李沐动手学深度学习V2-seq2seq和代码实现seq2seq中的循环神经网络编码器-解码器架构略有不同,下面描述了Bahdanau注意力的架构。
3. 定义注意力解码器
由于编码器跟之前seq2seq相同,因此只需重新定义解码器。 为了显示学习的注意力权重, 下面AttentionDecoder类定义了带有注意力机制解码器的基本接口。
import torch
import d2l.torch
from torch import nn
class AttentionDecoder(d2l.torch.Decoder):
"""带有注意力机制解码器的基本接口"""
def __init__(self):
super(AttentionDecoder,self).__init__()
@property
def attention_weights(self):
raise NotImplementedError
接下来在Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。 首先,初始化解码器的状态,需要下面的输入:
- 编码器在所有时间步的最后一层隐状态,将作为注意力的键和值;
- 编码器最后一个时间步所有层的隐状态,将作为初始化解码器所有层的隐状态;
- 编码器有效长度(在注意力池中排除填充词元,不使用填充词元作为attention权重)。
在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询,编码器在所有时间步的最后一层隐状态,将作为注意力的键和值。因此注意力输出和当前时间步目标序列输入嵌入连接起来作为循环神经网络解码器的输入。
class Seq2SeqAttentionDecoder(AttentionDecoder):
def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0):
super(Seq2SeqAttentionDecoder,self).__init__()
self.embedding = nn.Embedding(vocab_size,embed_size)
self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)
self.dense = nn.Linear(num_hiddens,vocab_size)
self.additiveAttention = d2l.torch.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens,num_hiddens=num_hiddens,dropout=dropout)
def init_state(self, enc_outputs, enc_valid_lens,*args):
outputs,hidden_states = enc_outputs
return (outputs.permute(1,0,2),hidden_states,enc_valid_lens)
def forward(self,X,states):
enc_outputs,hidden_states,enc_valid_lens = states
X = self.embedding(X).permute(1,0,2)
outputs = []
self._attention_weights = []
for x in X:
query = torch.unsqueeze(hidden_states[-1],dim=1)
context = self.additiveAttention(queries=query,keys=enc_outputs,values=enc_outputs,valid_lens=enc_valid_lens)
x_context = torch.cat((context,torch.unsqueeze(x,dim=1)),dim=-1)
output,hidden_states = self.rnn(x_context.permute(1,0,2),hidden_states)
outputs.append(output)
self._attention_weights.append(self.additiveAttention.attention_weights)
outputs = self.dense(torch.cat(outputs,dim=0))
return outputs.permute(1,0,2),[enc_outputs,hidden_states,enc_valid_lens]
@property
def attention_weights(self):
return self._attention_weights
使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。
encoder = d2l.torch.Seq2SeqEncoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
decoder.eval()
X = torch.zeros(size=(4,7),dtype=torch.long)
states = decoder.init_state(encoder(X),None)
outputs,states = decoder(X,states)
outputs.shape,len(states),states[0].shape,len(states[1]),states[1].shape,states[1][0].shape
输出结果如下:
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))
4. 训练
指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。 由于新增的注意力机制,训练要比没有注意力机制的seq2seq慢得多。
embed_size,num_hiddens,num_layers,dropout = 32,32,2,0.1
batch_size,num_steps = 64,10
lr=0.005
device = d2l.torch.try_gpu()
num_epochs = 250
train_iter,src_vocab,tgt_vocab = d2l.torch.load_data_nmt(batch_size,num_steps)
encoder = d2l.torch.Seq2SeqEncoder(len(src_vocab),embed_size,num_hiddens,num_layers,dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab),embed_size,num_hiddens,num_layers,dropout)
net = d2l.torch.EncoderDecoder(encoder,decoder)
d2l.torch.train_seq2seq(net,train_iter,lr,num_epochs,tgt_vocab,device)
模型训练完后,使用它将几个英语句子翻译成法语并计算它们的BLEU分数。
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng,fra in zip(engs,fras):
translations,dec_attention_weights = d2l.torch.predict_seq2seq(net,eng,src_vocab,tgt_vocab,num_steps,device,save_attention_weights=True)
print('eng:',eng,'==>','translations:',translations,'==> bleu:',d2l.torch.bleu(translations,fra,k=2))
输出结果如下:
eng: go . ==> translations: va ! ==> bleu: 1.0
eng: i lost . ==> translations: j'ai perdu . ==> bleu: 1.0
eng: he's calm . ==> translations: il est bon . ==> bleu: 0.6580370064762462
eng: i'm home . ==> translations: je suis chez moi . ==> bleu: 1.0
训练结束后,下面通过可视化注意力权重发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weights],dim=0).reshape(1,1,-1,num_steps)
d2l.torch.show_heatmaps(attention_weights[:,:,:,:len(engs[-1].split())+1].cpu(),xlabel='key positions',ylabel='query positions')
5. 小结
- 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
- 在基于注意力机制的循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步最后一层的解码器隐状态视为查询,将编码器在所有时间步最后一层的隐状态同时视为键和值。
6. 全部代码
import torch
import d2l.torch
from torch import nn
class AttentionDecoder(d2l.torch.Decoder):
"""带有注意力机制解码器的基本接口"""
def __init__(self):
super(AttentionDecoder, self).__init__()
@property
def attention_weights(self):
raise NotImplementedError
class Seq2SeqAttentionDecoder(AttentionDecoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
super(Seq2SeqAttentionDecoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Linear(num_hiddens, vocab_size)
self.additiveAttention = d2l.torch.AdditiveAttention(key_size=num_hiddens, query_size=num_hiddens,
num_hiddens=num_hiddens, dropout=dropout)
def init_state(self, enc_outputs, enc_valid_lens, *args):
outputs, hidden_states = enc_outputs
return (outputs.permute(1, 0, 2), hidden_states, enc_valid_lens)
def forward(self, X, states):
enc_outputs, hidden_states, enc_valid_lens = states
X = self.embedding(X).permute(1, 0, 2)
outputs = []
self._attention_weights = []
for x in X:
query = torch.unsqueeze(hidden_states[-1], dim=1)
context = self.additiveAttention(queries=query, keys=enc_outputs, values=enc_outputs,
valid_lens=enc_valid_lens)
x_context = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
output, hidden_states = self.rnn(x_context.permute(1, 0, 2), hidden_states)
outputs.append(output)
self._attention_weights.append(self.additiveAttention.attention_weights)
outputs = self.dense(torch.cat(outputs, dim=0))
return outputs.permute(1, 0, 2), [enc_outputs, hidden_states, enc_valid_lens]
@property
def attention_weights(self):
return self._attention_weights
encoder = d2l.torch.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros(size=(4, 7), dtype=torch.long)
states = decoder.init_state(encoder(X), None)
outputs, states = decoder(X, states)
outputs.shape, len(states), states[0].shape, len(states[1]), states[1].shape, states[1][0].shape
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr = 0.005
device = d2l.torch.try_gpu()
num_epochs = 250
train_iter, src_vocab, tgt_vocab = d2l.torch.load_data_nmt(batch_size, num_steps)
encoder = d2l.torch.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.torch.EncoderDecoder(encoder, decoder)
d2l.torch.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translations, dec_attention_weights = d2l.torch.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device,
save_attention_weights=True)
print('eng:', eng, '==>', 'translations:', translations, '==> bleu:', d2l.torch.bleu(translations, fra, k=2))
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weights], dim=0).reshape(1, 1, -1, num_steps)
dec_attention_weights #翻译最后一个序列得到的注意力分数,包含序列结束词元,因此总共需要六个预测词元,所以有六个注意力分数权重
attention_weights
# 加上一个包含序列结束词元
d2l.torch.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(), xlabel='key positions',
ylabel='query positions')
7. 相关链接
注意力机制第一篇:李沐动手学深度学习V2-注意力机制 注意力机制第二篇:李沐动手学深度学习V2-注意力评分函数 注意力机制第三篇:李沐动手学深度学习V2-基于注意力机制的seq2seq 注意力机制第四篇:李沐动手学深度学习V2-自注意力机制之位置编码 注意力机制第五篇:李沐动手学深度学习V2-自注意力机制 注意力机制第六篇:李沐动手学深度学习V2-多头注意力机制和代码实现 机器翻译第一篇:李沐动手学深度学习V2-机器翻译和数据集 机器翻译第二篇:李沐动手学深度学习V2-Encoder-Decoder编码器和解码器架构 机器翻译第三篇:李沐动手学深度学习V2-seq2seq和代码实现
|