IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> BERT(二)--论文理解:BERT 模型结构详解 -> 正文阅读

[人工智能]BERT(二)--论文理解:BERT 模型结构详解

转载请注明出处:https://blog.csdn.net/nocml/article/details/124951994

传送门:
BERT(一)–论文翻译:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
BERT(二)–论文理解:BERT 模型结构详解

Transformer系列:
Transformer(一)–论文翻译:Attention Is All You Need 中文版
Transformer(二)–论文理解:transformer 结构详解
Transformer(三)–论文实现:transformer pytorch 代码实现

1. 整体结构

??transformer是基于机器翻译任务提出的,采用了主流的encoder-decoder框架。而做为后来者的BERT,其核心特征提取模块延用了transformer中的encoder,继而又把此模块应用到了一个(两个子任务)新的任务上。个人看法,BERT在算法上并没有做很多的改进,但在算法的应用上,取得了很大的成功,奠定了预训练模型在NLP领域的主导地位。
??为了适用新的任务,BERT构建了自己的输入处理模块与输出处理模块。整体结构如下图:
在这里插入图片描述

??如上图所示模块1为输入模块,其作用是对原始的文字输入进行信息提取,分别提取了字信息,位置信息与句信息(文字是属于前一句还是后一句。)。模块2为特征提取模块,结构采用了transformer中的encoder结构,之前在transfor的文章里讲过,这里不在赘述。模块3为任务处理模块,主要是对模块2的输出做了相应的转换,以支持不同的子任务。后面的两个小节会详细讲解第1和第3个模块。

2. 输入处理

??BERT的输入与transformer的输入相比,多了一项句子特征,即当前字符是属于第一句话,还是属于第二句话。之所以增加这个特征,是因为BERT在训练时有个预测句子关系的任务。其它两个特征没有改变,还是字符特征和位置信息特征。但在采集位置信息时,做了一些调整,transfomer里使用的是正弦波,BERT里舍弃了这个相对复杂的方法,直接对位置下标做embedding。 最后把把3个embedding后的向量直接相加,得到最终的字符串表示。整个处理方法简单有效,从而也说明了特征抽取模块的学习能力的强大。这里我还是贴一下原始论文中的图。
在这里插入图片描述
其实上面这个图表达的很清楚了,但秉着事无巨细的态度,还是把相关的步骤细化一下。如下图:
在这里插入图片描述

上图中最终的输出是[sentence length,model size] 如果是输出了一个batch,那输出的shape应为[batch size, sentence length, model size](注:上面的shape是我习惯的叫法,也可以表达为[batch size, sequence length, hidden size])

3. 特征选择/学习模块

??此模块为Transformer中的encoder模块,具体参考我之前发的关于Transformer的blog

4. 输出模块

??Bert的训练有两个子任务,一个任务(NSP, Next Sentence Prediction)是预测输入中的A和B是不是上下句。另一个是预测随机mask掉的字符的任务(MLM, Masked LM)。两个子任务的输入均来自特征抽取模块,不同的是NSP任务的输入只选取了CLS对应的输出,而序列预测任务的输入则是除CLS对应位置的其它位置的数据。模型最终的损失是这两个子任务损失的加和。整体如下图。
在这里插入图片描述

4.1 NSP任务模块 Pooler

??在Pooler模块中,会取出每一句的第一个单词(CLS对应的位置数据),做全连接和激活。得到的输出用以做分类任务(NSP任务),整体流程如下图。
在这里插入图片描述

# transformers 中的源码
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

4.2 MLM任务输出模块

??MLM 任务的输出为BERT提取模块的输出(除CLS对应位置的输出),任务的目标是预测每个被mask掉的字符的原始字符是什么。按照transformers库中的实现来理解,具体的操作可以分为两部分:

  • 第一部分为输入转换,其具体操作为先对MLM模块的输出应用一个输出输入均为hidden_size的linear层,之后应用激活函数,这里的激活函数可以有多个选择,用户可以自定义,之后再做leyer normalization。这块比较简单,直接上源码,如下:
    class BertPredictionHeadTransform(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            if isinstance(config.hidden_act, str):
                self.transform_act_fn = ACT2FN[config.hidden_act]
            else:
                self.transform_act_fn = config.hidden_act
            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    
        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states = self.dense(hidden_states)
            hidden_states = self.transform_act_fn(hidden_states)
            hidden_states = self.LayerNorm(hidden_states)
            return hidden_states
    
  • 第二部分为把第一部分的输出中的每个字符分类,预测masked字符的原始字符,其操作是使用输入大小为hidden size, 输出大小为vocab size 的linear层对字符进行分类。源码如下:
    class BertLMPredictionHead(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.transform = BertPredictionHeadTransform(config)
    
            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    
            self.bias = nn.Parameter(torch.zeros(config.vocab_size))
    
            # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
            self.decoder.bias = self.bias
    
        def forward(self, hidden_states):
            hidden_states = self.transform(hidden_states)
            hidden_states = self.decoder(hidden_states)
            return hidden_states
    
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-09-25 23:12:38  更:2022-09-25 23:13:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 18:24:55-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计