IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Attention和Transformer的Pytorch实现 -> 正文阅读

[人工智能]Attention和Transformer的Pytorch实现

神经网络中Attention和Transformer结构

通过几个核心问题,探索NLP领域关键模块的结构

问题1,写一个Attention的代码?

import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(input_dim, num_heads)
        self.fc = nn.Linear(input_dim, hidden_dim)

    def forward(self, input, mask=None):
        attention_output, attention_weights = self.attention(input, input, input, mask=mask)
        output = self.fc(attention_output)
        return output, attention_weights

这个回答中,用了一个Pytorch现有的模块‘torch.nn.MultiheadAttention’,这不方便我们探索Attention的内部结构,因此我们抛出第二个问题。

问题2,写一个不需要该模块的实现

代码稍长

import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.query = nn.Linear(input_dim, hidden_dim)
        self.key = nn.Linear(input_dim, hidden_dim)
        self.value = nn.Linear(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, input, mask=None):
        batch_size, sequence_length, _ = input.shape
        query = self.query(input)  # (batch_size, sequence_length, hidden_dim)
        key = self.key(input)  # (batch_size, sequence_length, hidden_dim)
        value = self.value(input)  # (batch_size, sequence_length, hidden_dim)

        # Calculate attention weights
        attention_weights = torch.matmul(query, key.transpose(1, 2))  # (batch_size, sequence_length, sequence_length)
        attention_weights = attention_weights / (hidden_dim ** 0.5)  # Scale attention weights
        if mask is not None:
            attention_weights = attention_weights.masked_fill(mask == 0, -1e9)  # Mask padded positions
        attention_weights = torch.softmax(attention_weights, dim=-1)  # Normalize attention weights

        # Calculate attention-weighted output
        attention_output = torch.matmul(attention_weights, value)  # (batch_size, sequence_length, hidden_dim)
        attention_output = self.fc(attention_output)  # (batch_size, sequence_length, hidden_dim)
        return attention_output, attention_weights

从代码中可以看出,attention的结构是由4个全连接层(fully-connected layer,FC)构成

将数据输入后。并行输入到其中3个FC,它们分别输出三个值Query,Key, Value,简写为Q,K,V。经过矩阵计算,大致为QxKxV, 其中QxK后防止梯度消失加了个数据约束操作(Scale,和数据维度成正相关),再和V向乘。最终QKV的计算结果再进最后1个FC,得到结果。

另外一个重点,也就是Attention的精髓,就是有一个动态参数:attention_weights。这个参数设计初衷是解决机器翻译中长句子上下文关联问题,其作用是让模型在不同时间关注数据的不同部分。

问题3 写一个Transformer的实现

机器人回答代码如下:

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, dropout_p=0.1):
        super().__init__()
        self.attention = AttentionBlock(input_dim, hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, mask=None):
        residual = input
        input, attention_weights = self.attention(input, mask=mask)
        input = self.dropout(input)
        input = self.norm1(input + residual)
        residual = input
        input = self.feed_forward(input)
        input = self.dropout(input)
        input = self.norm2(input + residual)
        return input, attention_weights

可以看出,Transformer是由一个attention,加dropout和norm,再加一个2层的FC网络(这里称为feed_forward),再加dropout和norm。省略常用的数据约束操作(如激活函数,Norm正则等)。

Transformer的结构为:一个attention(3个并行FC和1个FC)和一个feed_forward(2层的FC)构成。

注意其中attention并行的3个FC用了QKV计算完成关联权重,以及2层FC中间是Relu,以及attention和feed_forward之间用了dropout和norm。

补充部分:

问题4,self-attention和attention的区别?

self-attention只是attention的一个特例,区别在于输入和输出(目的)。以NLP为例,attention输入一个句子和一组权重Weights,权重用于给出句子中每个单词的关联。self-attention输入多个句子和一组权重Weights,这个权重的目的是给出这些句子的关联。

问题5,Multi-head attention中的Multi-head是什么

这个也是Attention的变种,设计初衷是并行处理NLP长句子不同部分。就是把输入分解(split)为不同的部分head输入attention即可。比如把句子按词性拆分再输入。代码可以在原有attention基础上加入上述操作。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-12-25 11:10:19  更:2022-12-25 11:13:21 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/27 17:10:40-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码