| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 【轻量化深度学习】知识蒸馏与NLP语言模型的结合 -> 正文阅读 |
|
[人工智能]【轻量化深度学习】知识蒸馏与NLP语言模型的结合 |
Knowledge DistillationStudent : Wenxuan Zeng School : University of Electronic Science and Technology of China ? Date : 2022.3.25 - 2022.4.3 文章目录参考论文: Distilling the Knowledge in a Neural Network 这篇论文是知识蒸馏的开山之作,发表于NIPS’14,非常值得我们去学习研究。所以我先从这篇论文入手去学习知识蒸馏,然后去学习如何使用知识蒸馏去压缩BERT模型。 1 Knowledge的定义如果说知识就是模型中的参数,那么将难以迁移,因为两个不同的模型并没有一一对应的参数。教师网络预测结果中各个类别概率的相对大小隐式地包含了知识,在文中也称知识是从输入向量到输出向量的映射。举个直观的例子,对于小轿车的图片,模型会给出所有物体的预测概率,比如会有一部分概率是公交车,减小一部分概率是胡萝卜,那么教师网络就能教给学生网络这样的知识——这张图片大概率是一辆小轿车,不太可能是公交车或胡萝卜,并且这张图片更像公交车,而更不像胡萝卜。实际上,就是表明知识包含正确信息,同时也包含错误信息之间的相对关系。 2 Soft targets一种将笨重模型的泛化能力迁移到小模型的方式就是将笨重模型所产生的类别概率作为soft targets来训练小模型。Soft targets中包含了较高的熵,所以提供了更为详细的信息;而hard target (one-hot encoding)则熵低,提供较少的信息。 什么是Soft/Hard targets?举个例子,在三分类问题中,小轿车的hard target也许能表达成这样:(0, 0, 1),那soft target也许是这样的:(0.1, 0.3, 0.6)。显然,soft targets中包含了更多的信息,比如之前提到的“这张图片更像公交车,而更不像胡萝卜”类似的相对信息。 Hard loss: 3 T-Softmax复习一下softmax的作用,在做分类任务时,通过softmax将所有类别的概率压缩到 [0,1] 的范围内,并且概率值求和为1。Softmax表达式如下:
q
i
=
e
x
p
(
z
i
)
∑
j
e
x
p
(
z
j
)
q_i=\frac{exp(z_i)}{\sum_j exp(z_j)}
qi?=∑j?exp(zj?)exp(zi?)?
q
i
=
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
q_i=\frac{exp(z_i/T)}{\sum_j exp(z_j/T)}
qi?=∑j?exp(zj?/T)exp(zi?/T)? 结论:T越大,得到的预测结果越soft,各个类别的概率值越接近,所以其中包含的知识会更多。 4 知识蒸馏4.1 蒸馏流程下图是知识蒸馏的过程,教师网络在温度为t的时候训练,得到soft labels,学生网络是温度为t的时候训练,得到soft predictions,通过拟合soft labels和soft predictions,引导学生网络学习教师网络学到的知识**(比喻soft labels是老师的言传身教)。另外,学生网络在温度为1的时候训练,得到hard prediction,也就是one-hot encoding,然后用交叉熵损失函数与hard label计算出student loss(比喻hard label是课本知识)**。 4.2 Loss functionL = γ L h a r d + ( 1 ? γ ) T 2 L s o f t L = \gamma L_{hard} + (1-\gamma)T^2 L_{soft} L=γLhard?+(1?γ)T2Lsoft? 注意,在soft loss处需要乘上 T 2 T^2 T2,改变用于蒸馏的温度,硬目标和软目标的相对贡献大致保持不变。 4.3 预测值匹配是一种特殊形式的知识蒸馏在 Model Compression (SIGKDD’06) 这篇论文中,作者通过知识迁移实现了模型的压缩,详细来说就是将教师网络和学生网络的logits求得MSE。而在本文中,作者说这种压缩方式是蒸馏的一个特例。
?
C
?
z
i
=
1
T
(
q
i
?
p
i
)
=
1
T
(
e
z
i
/
T
∑
j
e
z
j
/
T
?
e
v
i
/
T
∑
j
e
v
j
/
T
)
\frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T} (\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i}/T}{\sum_j e^{v_j}/T})
?zi??C?=T1?(qi??pi?)=T1?(∑j?ezj?/Tezi?/T??∑j?evj?/Tevi?/T?) 假设蒸馏温度T足够高,那么根据泰勒展开: e x = 1 + x e^x=1+x ex=1+x
?
C
?
z
i
=
1
T
(
q
i
?
p
i
)
=
1
T
(
1
+
z
i
/
T
N
+
∑
j
z
j
/
T
?
1
+
v
i
/
T
N
+
∑
j
v
j
/
T
)
\frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{1+{z_i}/T}{N+\sum_j {z_j}/T}-\frac{1+{v_i}/T}{N+\sum_j {v_j}/T})
?zi??C?=T1?(qi??pi?)=T1?(N+∑j?zj?/T1+zi?/T??N+∑j?vj?/T1+vi?/T?)
?
C
?
z
i
≈
1
N
T
2
(
z
i
?
v
i
)
\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i-v_i)
?zi??C?≈NT21?(zi??vi?) 但是在实际情况下,并不能做到温度无穷大。 下图可以看出,温度太小的时候,很小的logits对应的softmax值被压到0,没有话语权,无法发挥蒸馏的效果;而温度太大的时候,所有类别的概率趋同,可能带来噪声。 温度T用多大比较好,这个需要靠经验决定,一般来说中间温度效果最佳。 4.4 知识蒸馏简单计算5 实验设计很奇妙的是,学生网络可以做到零样本学习,比如学生网络没有见过CNN中的平移不变性知识,但是仍然可以通过教师网络的知识迁移去学到。把数字3从学生网络的训练中抹掉,学生网络仍然可以从教师网络的知识中学到3的特征(作者手动调大了bias)。 6 知识蒸馏发展方向
7 知识蒸馏在NLP领域的研究在这部分,我选出了几篇非常经典的BERT蒸馏的论文,然后对BERT蒸馏的思想进行学习,下面是我的一些学习记录。 7.1 Distilled BiLSTM链接: Distilling Task-Specific Knowledge from BERT into Simple Neural Networks 方法: 教师模型采用fine-tune的BERT-LARGE模型,学生模型采用BiLSTM+ReLU,蒸馏目标是学生模型与hard labels的交叉熵+与BERT-LARGE的logits之间的MSE。 7.2 BERT-PKD链接: Patient knowledge distillation for bert model compression (ACL’19) 方法: 不直接从模型的最后一层进行蒸馏,而是从教师模型的中间层提取知识进行蒸馏。本文提出了两种不同的蒸馏方式:Skip-k层的蒸馏方式和最后k层的蒸馏方式。 通过交叉熵损失函数定义学生和教师模型之间的预测值差距: 除了让学生模仿教师,还定义了任务相关的交叉熵损失函数: 另外,还定义了标准化后的隐藏状态的MSE loss作为损失函数: 7.2 DistillBERT链接: Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter (NIPS’19) 方法: 在预训练阶段采用知识蒸馏技术压缩BERT,为了利用预训练时从大模型中学到的归纳偏差,引入了结合了语言建模、蒸馏和余弦距离损失的三元loss。 7.3 TinyBERT链接:Tinybert: Distilling bert for natural language understanding (ACL’20) 方法:提出了two-stage learning framework,分别在预训练和fine-tune阶段蒸馏教师模型,得到了参数量减少7.5倍,速度提升9.4倍的4层BERT,效果可以达到教师模型的96.8%,同时这种方法训出的6层模型甚至接近BERT-base,超过了BERT-PKD和DistillBERT。本文提出注意力矩阵的蒸馏,用MSE作为损失函数拟合教师和学生的注意力矩阵。 同时对embedding layer和hidden layer都做知识蒸馏,同样采用MSE作为损失函数: 最后,用交叉熵损失函数去衡量教师和学生模型的logits差距: 综合上面提到的蒸馏目标,根据蒸馏的layer,决定采用哪个蒸馏的loss: 7.4 MobileBERT链接: MobileBERT:a Compact Task-Agnostic BERT for Resource-Limited Devices (ACL’20) 方法: 采用了瓶颈结构和自注意力与前馈神经网络的平衡机制,将知识从教师模型蒸馏到学生模型,使模型具有更窄的宽度。(具体笔记在前面文档中的Paper Understanding部分有写) 7.5 MiniLM方法: 虽然之前的文章把模型蒸馏了个遍,从embeddin layer到hidden layer,又到attention layer,最后到prediction layer,但是本文仍然找了一个新的点去蒸馏,并取得了非常好的效果。这篇文章蒸馏self-attention模块,提出value之间的scaled dot-product (value-relation) 作为新的深度自注意力知识。另外,本文用了一个teacher assistant去辅助大模型的蒸馏。 自注意力矩阵之间的关系用KL散度来衡量: 下面是本文定义的value-relation,实际上就是对value做scaled dot product,然后用KL散度衡量两个VR矩阵: 结论: 这篇文章说“只蒸馏最后一层效果比layer-to-layer要好,而且不用严格去对应两个模型的每一层,只蒸馏最后一层也能提高学生的性能,并使学生具有更强的泛化能力”。 8 我对知识蒸馏的思考 ?
前面的总结是之前写的,这里是来自两周后的补充: |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/6 18:07:28- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |