IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Seq2Seq中常见注意力机制的实现 -> 正文阅读

[人工智能]Seq2Seq中常见注意力机制的实现

引言

本文通过Pytorch实现了Seq2Seq中常用的注意力方式。

注意力方式

s c o r e ( h t , h  ̄ s ) = { h t T h  ̄ s dot h t T W a h  ̄ s general v a T tanh ? ( W a [ h t ; h  ̄ s ] ) concat v a T tanh ? ( W a h  ̄ s + U a h t ) bahdanau score(h_t, \overline{h}_s) = \begin{cases} h_t^T \overline{h}_s & \text{dot} \\ h_t^T W_a \overline{h}_s & \text{general} \\ v_a^T \tanh (W_a[h_t; \overline{h}_s]) & \text{concat} \\ v_a^T \tanh (W_a\overline{h}_s + U_a h_t) & \text{bahdanau} \end{cases} score(ht?,hs?)=??????????htT?hs?htT?Wa?hs?vaT?tanh(Wa?[ht?;hs?])vaT?tanh(Wa?hs?+Ua?ht?)?dotgeneralconcatbahdanau?

结合论文Effective Approaches to Attention-based Neural Machine TranslationNEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE,我们得到上面四种计算注意力的方式。

编码器的每个输出 h i h_i hi?对应的权重 α i j \alpha_{ij} αij?通过如下公式计算:
α i j = e x p ( e i j ) ∑ k = 1 T x e x p ( e i k ) (6) \alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x} exp(e_{ik})} \tag{6} αij?=k=1Tx??exp(eik?)exp(eij?)?(6)
其中
e i j = a ( s i ? 1 , h j ) e_{ij} = a(s_{i-1},h_j) eij?=a(si?1?,hj?)

(论文翻译) NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE

代码实现

import torch.nn as nn
import torch


class Attention(nn.Module):

    def __init__(self, hidden_size, method='dot'):
        super(Attention, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method not in ['dot', 'general', 'concat', 'bahdanau']:
            raise ValueError(self.method, "is not an appropriate attention method.")

        if self.method == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            self.va = nn.Parameter(torch.FloatTensor(1, hidden_size))
        elif self.method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def _score(self, last_hidden, encoder_outputs):
        '''

        :param last_hidden: 解码器最后一层(若有多层的话)的输出 [1,batch_size,hidden_size] 解码器一次只处理一个时间步,并且只有一个方向: D=1
        :param encoder_outputs: 编码器所有时间步的隐藏状态 [seq_len, batch_size, hidden_size]
        '''

        if self.method == 'dot':
            # last_hidden * encoder_outputs [seq_len, batch_size, hidden_size]
            # sum(x, dim=2) 将第2个维度的值累计,累计第2个维度的值,使其维度大小变成1,并移除,得到 [seq_len, batch_size]
            # 计算每个批次内, 解码器当前时间步 与编码器每个时间步的 权重得分
            # 计算e_i
            return torch.sum(last_hidden * encoder_outputs, dim=2)  # [seq_len, batch_size]
        elif self.method == 'general':
            energy = self.Wa(last_hidden)  # [1, batch_size, hidden_size]
            # [seq_len, batch_size, hidden_size] x [1, batch_size, hidden_size] = [seq_len, batch_size, hidden_size]
            return torch.sum(encoder_outputs * energy, dim=2)  # [seq_len, batch_size]

        elif self.method == 'concat':
            # last_hidden.expand(encoder_outputs.size(0), -1, -1)) # [seq_len, batch_size, hidden_size] 对维度0进行复制操作
            # 复制seq_len份,以支持cat操作
            # cat(*, dim=2)   [seq_len, batch_size, hidden_size*2]
            # energy = tanh(self.Wa(*))  [seq_len,batch_size, hidden_size]
            energy = torch.tanh(
                self.Wa(torch.cat((encoder_outputs, last_hidden.expand(encoder_outputs.size(0), -1, -1)), dim=2)))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

        else:  # method == 'bahdanau'
            # self.Wa(last_hidden)  [1,batch_size,hidden_size]
            # self.Ua(encoder_outputs) [seq_len, batch_size, hidden_size]
            # torch.tanh(*)  [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(last_hidden) + self.Ua(encoder_outputs))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

    def forward(self, last_hidden, encoder_outputs):
        # 注意力得分,见_score方法,返回的大小都是 [seq_len, batch_size]
        attn_energies = self._score(last_hidden, encoder_outputs)
        # 转置 [batch_size, seq_len]
        attn_energies = attn_energies.t()
        # 经过softmax,得到权重系数,我们要计算对每个时间步的权重,所以沿着时间步的维度计算
        # 并且计算之后,形状保持不变。
        # 计算上面公式(6) α_i
        return torch.softmax(attn_energies, dim=1) \
            .unsqueeze(1)  # unsqueeze(1) 在dim=1处,扩展一个维度,形状变成 [batch_size, 1, seq_len]


  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-08 11:20:38  更:2021-08-08 11:21:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 3:57:57-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码