深度学习神经网络论文研读-自然语言处理方向-elctra-目录
概念引入
ELECTRA比BERRT快的原因
背景 当今的SOTA的预训练语言模型,比如BERT,采用Mask language model(MLM)的方式破坏输入的内容,通过双向语言模型进行预测重构;然而这存在一个问题,那就是MASK这个token在训练中存在但是在实际预测中不存在,为了缓解这个问题,BERT采用了选择语料中15%的TOKEN,在其中80%进行MASK,10%随机替换,10%不变,这的确稍微缓解了训练预测不一致的问题(虽然在XLNet利用permutation language model得到解决),但是确使得BERT必须利用更多的训练语料,需要的算力也大幅增加,为此提出了ELECTRA这个模型解决这个问题 对应的解决方案 为了解决上述说的训练慢,数据要求多的问题,ELECTRA中训练不只是用语料中的subset(即BERT中只是MASK的token)进行预测,而是利用全部的token. 为了达成这个目的,作者训练语言模型的时候不是像bert一样把他看作generator(bert中通过重构被MASK的词,某种程度上可以看成为generator),而是看成discriminator,论文中引入另一个generator去生成相似的词进行替换,训练语言模型的任务就是去判断语料中的每个词是不是被替换了,这里有点对抗学习(GAN)的意思,但是这里并不是用GAN(因为GAN在本文和图片不一样不是连续的,将GAN用在文本生成上有难度)
摘要
大意 1.MLM的预训练方法类似于Bert的破坏输入,通过用mask标志替换,并且训练模型去重塑最原始的输 入。当它们在下游任务中产生好的结果,但是需要大量的数据。 2.我们提出一个更有效的预训练任务,方法是从小型生成器里采样的合理的替代品替换一些 token 来破坏输入。然后训练一个判别模型,该模型可以预测损坏的输入中的每个 token是否被生成器样本 取代。 3.我们的模型取和Bert相同的数据量、模型参数,所取得结果要优于Bert系列模型,并且训练的时 间大大的减少。
elctra的判别器与生成器
两个都是transformer的encoder结构,只是两个网络的尺寸不同: generator-生成器:就是一个小的 masked language model(一般是 1/4 discriminator的size),该模块的具体作用是他采用了经典的bert的MLM方式: 即首先随机选取15%的tokens,替代为MASK token,(取消了bert的80%MASK,10%unchange, 10% random replaced 的操作,原因是因为没必要,因为本文中finetuning使用的discriminator) 使用generator去训练模型,使得模型预测masked token,得到corrupted tokens generator的目标函数和bert一样,都是希望被masked的能够被还原成原本的original tokens 如上图, token,the 和 cooked 被随机选为被masked,然后generator预测得到corrupted tokens,变成了the和ate
discriminator-判别器:discriminator的接收被generator corrupt之后的输入,discriminator的作用是分辨输入的每一个token是original的还是replaced,如果generator生成的token和原始token一致,那么这个token仍然是original的 所以,对于每个token,discriminator都会进行一个二分类,最后获得loss
以上的方式被称为replaced token detection
模型训练
生成器部分 该模型采用minimize the combined loss的方式进行训练 公式: 输入: 经过generator运行之后得到编码了上下文信息的vector representation 对于位置 t, 其被替换为MASK,那么它的output probability(经过softmax/逻辑回归从而达到二分类的目的)为 对应的LOSS函数
判别器部分 输入X_correct为生成器替换后的sample序列,hD为transformer,w为权重矩阵
对应的loss函数
整个网络的loss
其他训练方式
论文中还提出两种训练方式
- 一种是GAN:ELECTRA 以一种对抗学习的思想来训练。作者将生成器的目标函数由最小化MLM loss换成了最大化判别器在被替换token上的RTD loss。但还有一个问题,就是新的生成器loss无法用梯度下降更新生成器,于是作者用强化学习Policy Gradient的思想,将被替换token的交叉熵作为生成器的reward,然后进行梯度下降。强化方法优化下来生成器在MLM任务上可以达到54%的准确率,而之前MLE优化下可以达到65%。
- 一种是two-stage 训练,先训练
n steps,然后froze住 generator,再训练 discrininator n steps
但是效果没有共同训练效果好
效果比较
不比Bert差而且速度快
论文意义
通过RTD的训练方式大大减少预训练时间
在轻量级模型中有着优异的表现
|