IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习神经网络论文研读-自然语言处理方向-elctra -> 正文阅读

[人工智能]深度学习神经网络论文研读-自然语言处理方向-elctra

概念引入

ELECTRA比BERRT快的原因

背景
当今的SOTA的预训练语言模型,比如BERT,采用Mask language model(MLM)的方式破坏输入的内容,通过双向语言模型进行预测重构;然而这存在一个问题,那就是MASK这个token在训练中存在但是在实际预测中不存在,为了缓解这个问题,BERT采用了选择语料中15%的TOKEN,在其中80%进行MASK,10%随机替换,10%不变,这的确稍微缓解了训练预测不一致的问题(虽然在XLNet利用permutation language model得到解决),但是确使得BERT必须利用更多的训练语料,需要的算力也大幅增加,为此提出了ELECTRA这个模型解决这个问题
对应的解决方案
为了解决上述说的训练慢,数据要求多的问题,ELECTRA中训练不只是用语料中的subset(即BERT中只是MASK的token)进行预测,而是利用全部的token. 为了达成这个目的,作者训练语言模型的时候不是像bert一样把他看作generator(bert中通过重构被MASK的词,某种程度上可以看成为generator),而是看成discriminator,论文中引入另一个generator去生成相似的词进行替换,训练语言模型的任务就是去判断语料中的每个词是不是被替换了,这里有点对抗学习(GAN)的意思,但是这里并不是用GAN(因为GAN在本文和图片不一样不是连续的,将GAN用在文本生成上有难度)

摘要

大意
1.MLM的预训练方法类似于Bert的破坏输入,通过用mask标志替换,并且训练模型去重塑最原始的输
入。当它们在下游任务中产生好的结果,但是需要大量的数据。
2.我们提出一个更有效的预训练任务,方法是从小型生成器里采样的合理的替代品替换一些 token
来破坏输入。然后训练一个判别模型,该模型可以预测损坏的输入中的每个 token是否被生成器样本
取代。
3.我们的模型取和Bert相同的数据量、模型参数,所取得结果要优于Bert系列模型,并且训练的时
间大大的减少。

elctra的判别器与生成器

两个都是transformer的encoder结构,只是两个网络的尺寸不同:
在这里插入图片描述
generator-生成器:就是一个小的 masked language model(一般是 1/4 discriminator的size),该模块的具体作用是他采用了经典的bert的MLM方式:
即首先随机选取15%的tokens,替代为MASK token,(取消了bert的80%MASK,10%unchange, 10% random replaced 的操作,原因是因为没必要,因为本文中finetuning使用的discriminator)
使用generator去训练模型,使得模型预测masked token,得到corrupted tokens
generator的目标函数和bert一样,都是希望被masked的能够被还原成原本的original tokens
如上图, token,the 和 cooked 被随机选为被masked,然后generator预测得到corrupted tokens,变成了the和ate

discriminator-判别器:discriminator的接收被generator corrupt之后的输入,discriminator的作用是分辨输入的每一个token是original的还是replaced,如果generator生成的token和原始token一致,那么这个token仍然是original的
所以,对于每个token,discriminator都会进行一个二分类,最后获得loss

以上的方式被称为replaced token detection

模型训练

生成器部分
该模型采用minimize the combined loss的方式进行训练
公式:
在这里插入图片描述
输入:
在这里插入图片描述
经过generator运行之后得到编码了上下文信息的vector representation
在这里插入图片描述
对于位置 t, 其被替换为MASK,那么它的output probability(经过softmax/逻辑回归从而达到二分类的目的)为
在这里插入图片描述
对应的LOSS函数
在这里插入图片描述

判别器部分
在这里插入图片描述
输入X_correct为生成器替换后的sample序列,hD为transformer,w为权重矩阵
在这里插入图片描述

对应的loss函数
在这里插入图片描述

整个网络的loss
在这里插入图片描述

其他训练方式

论文中还提出两种训练方式
在这里插入图片描述

  • 一种是GAN:ELECTRA 以一种对抗学习的思想来训练。作者将生成器的目标函数由最小化MLM loss换成了最大化判别器在被替换token上的RTD loss。但还有一个问题,就是新的生成器loss无法用梯度下降更新生成器,于是作者用强化学习Policy Gradient的思想,将被替换token的交叉熵作为生成器的reward,然后进行梯度下降。强化方法优化下来生成器在MLM任务上可以达到54%的准确率,而之前MLE优化下可以达到65%。

  • 一种是two-stage 训练,先训练 [公式]
    n steps,然后froze住 generator,再训练 discrininator n steps

但是效果没有共同训练效果好

效果比较

在这里插入图片描述
不比Bert差而且速度快

论文意义

通过RTD的训练方式大大减少预训练时间

在轻量级模型中有着优异的表现

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-05 17:21:26  更:2021-08-05 17:23:20 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 15:11:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码