IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于注意力机制的seq2seq模型 -> 正文阅读

[人工智能]基于注意力机制的seq2seq模型

一、前言

在此之前,我们实现了最普通的seq2seq模型,该模型的编码器和解码器均采用的是两层单向的GRU。本篇文章将基于注意力机制改进之前的seq2seq模型,其中编码器采用两层双向的LSTM,解码器采用含有注意力机制的两层单向LSTM。由于数据预处理部分相同,因此本文不再赘述,详情可参考之前的文章。

二、模型搭建

本文接下来的叙述将沿用这篇文章中的符号。

2.1 编码器

编码器我们采用两层双向LSTM。编码器的输入形状为 ( N , L ) (N,L) (N,L),输出 output 的形状为 ( L , N , 2 h ) (L,N,2h) (L,N,2h),它是正向LSTM和反向LSTM输出进行了concat后的结果,包含了正反向的信息。编码器输出的 h_nc_n 的形状均为 ( 2 n , N , h ) (2n,N,h) (2n,N,h),需要将其形状改变为 ( n , N , 2 h ) (n,N,2h) (n,N,2h) 后才可作为解码器的初始隐状态。

至于为什么要改变 h_nc_n 的形状以及为什么不能直接用 reshape 去改变会在后面提到。

编码器的实现如下:

class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)
        self.rnn = nn.LSTM(emb_size, hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=True)

    def forward(self, encoder_inputs):
        encoder_inputs = self.embedding(encoder_inputs).permute(1, 0, 2)
        output, (h_n, c_n) = self.rnn(encoder_inputs)  # output shape: (seq_len, batch_size, 2 * hidden_size)
        h_n = torch.cat((h_n[::2], h_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)
        c_n = torch.cat((c_n[::2], c_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)
        return output, h_n, c_n

2.2 注意力机制

在原先的seq2seq模型中,解码器在每一个时间步所使用的上下文向量均相同。现在我们希望解码器在不同的时间步上能够注意到源序列中不同的信息,因此考虑采用注意力机制。

解码器的核心架构为两层单向的LSTM(只能是单向),在 t t t 时刻,我们采用解码器在 t ? 1 t-1 t?1 时刻最后一个隐层的输出作为查询,每个 output[t] 既作为键也作为值,相应的计算上下文向量的公式如下:

context [ t ] = ∑ t = 1 L α ( decoder_state [ t ? 1 ] , output [ t ] ) ? output [ t ] \text{context}[t]=\sum_{t=1}^L \alpha(\text{decoder\_state}[t-1], \text{output}[t])\cdot \text{output}[t] context[t]=t=1L?α(decoder_state[t?1],output[t])?output[t]

其中 α ( q , k ) \alpha(q,k) α(q,k) 是注意力权重。

假设编码器所采用的LSTM的隐层大小为 h h h,解码器所采用的LSTM的隐层大小为 h ′ h' h。因 output[t] 的形状为 ( N , 2 h ) (N,2h) (N,2h)decoder_state[t - 1] 的形状为 ( N , h ′ ) (N,h') (N,h),要使用缩放点积注意力,则必须有 h ′ = 2 h h'=2h h=2h,否则无法进行内积操作,所以可以得出:解码器隐层大小是编码器的两倍

注意力机制实现如下:

class AttentionMechanism(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, decoder_state, encoder_output):
        # 解码器的隐藏层大小必须是编码器的两倍,否则无法进行接下来的内积操作
        # decoder_state shape: (batch_size, 2 * hidden_size)
        # encoder_output shape: (seq_len, batch_size, 2 * hidden_size)
        decoder_state = decoder_state.unsqueeze(1)  # (batch_size, 1, 2 * hidden_size)
        encoder_output = encoder_output.transpose(0, 1)  # (batch_size, seq_len, 2 * hidden_size)
        # scores shape: (batch_size, seq_len)
        scores = torch.sum(decoder_state * encoder_output, dim=-1) / math.sqrt(decoder_state.shape[2])  # 广播机制
        attn_weights = F.softmax(scores, dim=-1)
        # context shape: (batch_size, 2 * hidden_size)
        context = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)  # 广播机制
        return context

2.3 解码器

解码器的原始输入的形状为 ( N , L ) (N,L) (N,L),通过嵌入以及 permute 操作后其形状变为 ( L , N , d ) (L,N,d) (L,N,d),而上下文的形状为 ( N , 2 h ) (N,2h) (N,2h) ( L , N , d ) (L,N,d) (L,N,d) ( N , 2 h ) (N,2h) (N,2h) 进行concat后才会作为其内部LSTM的输入。因此解码器采用的LSTM的 input_size d + 2 h d+2h d+2h。为了保证注意力机制正常运作,其隐藏层大小也应为编码器的两倍,即 2 h 2h 2h

从编码器我们得到了形状为 ( 2 n , N , h ) (2n,N,h) (2n,N,h)h_n,而解码器采用的LSTM是单向的,从而其接受的 h_0 的形状应为 ( n , N , 2 h ) (n,N,2h) (n,N,2h)。一个很自然的想法是直接使用 reshape 完成形状的转化,但这样做会带来一个问题,即无法保证 h_0[-1] 对应的是正反向编码器在最后一个时间步最后一个隐层的输出的拼接,为此可考虑采用如下方式解决:

h 0 = Concat ( ( h n [ : ? : 2 ] , h n [ 1 : ? : 2 ] ) , ?? dim = 2 ) h_0 =\text{Concat}((h_n [ : \, : 2],h_n [ 1: \, : 2]),\;\text{dim} = 2) h0?=Concat((hn?[::2],hn?[1::2]),dim=2)

至于为什么这样做,可以参考这篇文章

在评估阶段中,我们往往需要利用模型的解码器一步一步地输出,每一时刻都会利用上一时刻解码器输出的隐状态,类似于下面的伪代码:

decoder_output, hidden_state = decoder(decoder_input, hidden_state)

这要求输入到解码器中的隐状态和解码器输出的隐状态的形状必须相同因此 h_nc_n 的形状转化必须在编码器中完成

解码器的实现:

class Seq2SeqDecoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)
        self.attn = AttentionMechanism()
        self.rnn = nn.LSTM(emb_size + 2 * hidden_size, 2 * hidden_size, num_layers=num_layers, dropout=dropout)
        self.fc = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, decoder_inputs, encoder_output, h_n, c_n):
        decoder_inputs = self.embedding(decoder_inputs).permute(1, 0, 2)  # (seq_len, batch_size, emb_size)
        # 注意将其移动到GPU上
        decoder_output = torch.zeros(decoder_inputs.shape[0], *h_n.shape[1:]).to(device)  # (seq_len, batch_size, 2 * hidden_size)
        for i in range(len(decoder_inputs)):
            context = self.attn(h_n[-1], encoder_output)  # (batch_size, 2 * hidden_size)
            # single_step_output shape: (1, batch_size, 2 * hidden_size)
            single_step_output, (h_n, c_n) = self.rnn(torch.cat((decoder_inputs[i], context), -1).unsqueeze(0), (h_n, c_n))
            decoder_output[i] = single_step_output.squeeze()
        logits = self.fc(decoder_output)  # (seq_len, batch_size, vocab_size)
        return logits, h_n, c_n

2.4 Seq2Seq模型

整体架构如下:

只需要将编码器和解码器封装在一起即可:

class Seq2SeqModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_inputs, decoder_inputs):
        return self.decoder(decoder_inputs, *self.encoder(encoder_inputs))

三、模型的训练与评估

因为输入输出发生了一些变化,我们只需要对原先的 train 函数和 evaluate 函数稍作修改

def train(train_loader, model, criterion, optimizer, num_epochs):
    train_loss = []
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (encoder_inputs, decoder_targets) in enumerate(train_loader):
            encoder_inputs, decoder_targets = encoder_inputs.to(device), decoder_targets.to(device)
            bos_column = torch.tensor([tgt_vocab['<bos>']] * decoder_targets.shape[0]).reshape(-1, 1).to(device)
            decoder_inputs = torch.cat((bos_column, decoder_targets[:, :-1]), dim=1)
            pred, _, _ = model(encoder_inputs, decoder_inputs)
            loss = criterion(pred.permute(1, 2, 0), decoder_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            if (batch_idx + 1) % 50 == 0:
                print(
                    f'[Epoch {epoch + 1}] [{(batch_idx + 1) * len(encoder_inputs)}/{len(train_loader.dataset)}] loss: {loss:.4f}')
        print()
    return train_loss


def evaluate(test_loader, model, bleu_k):
    bleu_scores = []
    translation_results = []
    model.eval()
    for src_seq, tgt_seq in test_loader:
        encoder_inputs = src_seq.to(device)
        encoder_output, h_n, c_n = model.encoder(encoder_inputs)
        pred_seq = [tgt_vocab['<bos>']]
        for _ in range(SEQ_LEN):
            decoder_inputs = torch.tensor(pred_seq[-1]).reshape(1, 1).to(device)
            pred, h_n, c_n = model.decoder(decoder_inputs, encoder_output, h_n, c_n)
            next_token_idx = pred.squeeze().argmax().item()
            if next_token_idx == tgt_vocab['<eos>']:
                break
            pred_seq.append(next_token_idx)
        pred_seq = tgt_vocab[pred_seq[1:]]
        tgt_seq = tgt_seq.squeeze().tolist()
        tgt_seq = tgt_vocab[tgt_seq[:tgt_seq.index(tgt_vocab['<eos>'])]] if tgt_vocab['<eos>'] in tgt_seq else tgt_vocab[tgt_seq]
        translation_results.append((' '.join(tgt_seq), ' '.join(pred_seq)))
        if len(pred_seq) >= bleu_k:
            bleu_scores.append(bleu(tgt_seq, pred_seq, k=bleu_k))

    return bleu_scores, translation_results

保持其他超参数不变,使用 NVIDIA A40 进行训练(亲测 RTX 3090 会爆掉显存),大概需要6个小时,损失函数曲线如下:

在这里插入图片描述

与之前不同的是,在评估阶段,我们会分别计算平均BLEU-{2,3,4}分数并与原先的模型进行比较

bleu_2_scores, _ = evaluate(test_loader, net, bleu_k=2)
bleu_3_scores, _ = evaluate(test_loader, net, bleu_k=3)
bleu_4_scores, _ = evaluate(test_loader, net, bleu_k=4)
print(f"BLEU-2: {np.mean(bleu_2_scores)} | BLEU-3: {np.mean(bleu_3_scores)} | BLEU-4: {np.mean(bleu_4_scores)}")

比较结果列在下表中

模型平均BLEU-2平均BLEU-3平均BLEU-4
Vanilla Seq2Seq(链接0.47990.32290.2144
Attention-based Seq2Seq(本文)0.57110.41950.3036

可以看出加入了注意力机制后,BLEU得分提升了约十个百分点

一些可以改进的地方:

  • 完全可以先将 translation_results 计算出来再计算每种BLEU得分,这样做可以大大节省时间;
  • 训练过程中Teacher Forcing的比率为100%,可以尝试降低此比率以达到更好的效果;
  • BLEU无法理解同义词,导致一些合理的翻译会被否定,可以尝试换用其他的度量来更准确地评估模型。

附录一、翻译效果比较

translation_results 中随机抽取十个。

target:     je suis plut?t occupée .
vanilla:    je suis plut?t occupé .
attn-based: je suis plut?t occupé .

target:     ?a t'arrive de dormir ?
vanilla:    t'arrive-t-il de dormir ?
attn-based: t'arrive-t-il de dormir ?

target:     je ne partirai probablement pas demain .
vanilla:    je ne vais probablement pas vouloir demain .
attn-based: je ne serai probablement pas demain .

target:     je suis prudent .
vanilla:    je suis prudente .
attn-based: je suis prudente .

target:     je suis sure que c'était juste un malentendu .
vanilla:    je suis s?r que c'était un malentendu .
attn-based: je suis s?r que ce fut un malentendu .

target:     je me demandais ce qui t'avait fait changer d'avis .
vanilla:    je me demandais ce que tu ressens .
attn-based: je me demandais ce qui aurait réussi à ce sujet .

target:     il me jeta un regard sévère .
vanilla:    il me fit une robe bleue .
attn-based: il m'a donné un grand regard .

target:     te fies-tu à qui que ce soit ?
vanilla:    vous fiez-vous à quiconque ?
attn-based: te fies-tu à quiconque ?

target:     es-tu s?re d'avoir assez chaud ?
vanilla:    es-tu s?r que tu es allé ?
attn-based: êtes-vous s?r d'avoir assez chaud ?

target:     je commen?ais à me faire du souci à ton sujet .
vanilla:    je commen?ais à m'inquiéter pour toi .
attn-based: je commen?ais à m'inquiéter à votre sujet .

附录二、完整代码

💻 完整代码请前往 eng-fra-seq2seq 进行查看。码文不易,下载时还请您随手给一个follow和star,谢谢!

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-08-19 19:04:57  更:2022-08-19 19:07:04 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/25 23:29:48-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码