IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【模型蒸馏】TinyBERT: Distilling BERT for Natural Language Understanding -> 正文阅读

[人工智能]【模型蒸馏】TinyBERT: Distilling BERT for Natural Language Understanding

总述

TinyBert主要探究如何使用模型蒸馏来实现BERT模型的压缩。
主要包括两个创新点:

  1. 对Transformer的参数进行蒸馏,需要同时注意embedding,attention_weight, 过完全连接层之后的hidden,以及最后的logits。
  2. 对于预训练语言模型,要分为pretrain_model 蒸馏以及task-specific蒸馏。分别学习pretrain模型的初始参数以便给压缩模型的参数一个好的初始化,第二步学习pretrain model fine-tuning的logits让压缩模型再次学习。

模型

模型主要分成三个部分:

  1. Transformer Layer的蒸馏
    主要蒸馏两部分,第一是每一层的attention weight,第二是每一层输出的hidden。如下图所示。
    在这里插入图片描述

公式:
在这里插入图片描述
在这里插入图片描述
使用均方误差作为损失函数, 并且在hidden对比的时候引入了一个Wh,这是因为学生模型和老师模型的向量编码维度不一致(学生模型的向量维度要更小)
2. Embedding layer的蒸馏
在这里插入图片描述
E表示embeddign层的输出。
3. Predict logits的蒸馏
在这里插入图片描述
z表示老师模型与学生模型在task-specific任务上的预测概率。

此外还有一个细节便是数据增强,学生模型在task-specific任务上fine-tuning的时候,Tinybert对原数据集做了数据增强。(ps:这其实非常奇怪,因为后文实验中可以看到,去除数据增强之后,模型的效果比之之前的sota并无太大提升。而文章主要的卖点是模型蒸馏ummm)

实验和结论

  1. 蒸馏各个层次的重要性
    在这里插入图片描述
    可以看出,从重要性来说: Attn > Pred logits > Hidn > emb. 其中,Attn,Hidn以及emb在两个阶段的蒸馏中均有用到。

  2. 数据增强的重要性
    在这里插入图片描述
    GD (General Distillation)表示第一阶段蒸馏。
    TD (Task-specific Distillation)表示第二阶段蒸馏.
    and DA (Data Augmentation).表示数据增强。
    这张表得到的结论是,数据增强很重要 : (。

  3. 学生模型需要学习老师模型的哪些层
    在这里插入图片描述
    假设学生模型4层,老师模型12层
    top表示学生模型学习老师的后4层(10,11,12),bottom表示学习老师模型的前4层(1,2,3,4),uniform表示均匀学习(等间距,3,6,9,12)。
    可以看到,均匀学习各层的效果更好。

代码

# 此部分代码应该写在Trainer里面, loss.backward之前。
# 获取学生模型的logits, attention_weight以及hidden
 student_logits, student_atts, student_reps = student_model(input_ids, segment_ids, input_mask,
                                                            is_student=True)
# 在测试环境下获取老师模型的logits, attention_weight以及hidden
 with torch.no_grad():
     teacher_logits, teacher_atts, teacher_reps = teacher_model(input_ids, segment_ids, input_mask)

# 分为两步,一步是学习attentino_weight和hidden,还有一步是学习predict_logits。 总思想就是对学生模型的输出和老师模型的输出做loss,其中针对attention_weight和hidden是MSE loss, 针对logits是交叉熵。
 if not args.pred_distill:
     teacher_layer_num = len(teacher_atts)
     student_layer_num = len(student_atts)
     assert teacher_layer_num % student_layer_num == 0
     layers_per_block = int(teacher_layer_num / student_layer_num)
     new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                         for i in range(student_layer_num)]

     for student_att, teacher_att in zip(student_atts, new_teacher_atts):
         student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
                                   student_att)
         teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
                                   teacher_att)

         tmp_loss = loss_mse(student_att, teacher_att)
         att_loss += tmp_loss

     new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
     new_student_reps = student_reps
     for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
         tmp_loss = loss_mse(student_rep, teacher_rep)
         rep_loss += tmp_loss

     loss = rep_loss + att_loss
     tr_att_loss += att_loss.item()
     tr_rep_loss += rep_loss.item()
 else:
     if output_mode == "classification":
         cls_loss = soft_cross_entropy(student_logits / args.temperature,
                                       teacher_logits / args.temperature)
     elif output_mode == "regression":
         loss_mse = MSELoss()
         cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))

     loss = cls_loss
     tr_cls_loss += cls_loss.item()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-26 15:18:31  更:2022-05-26 15:18:36 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 4:43:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码