本文介绍一篇自然语言处理剪枝论文(Network Pruning Rethinking)
上图介绍了在不同的修剪策略下,知识是如何传递的:
(a) 一般的预培训和微调程序。
g
g
g是一个编码器。
g
L
g_L
gL?和
g
L
D
g_{L_D}
gLD??分别是在预训练数据集和微调数据集上训练良好的编码器。
L
L
L和
D
D
D分别是通用语言知识和任务特定知识。预训练和测试之间存在域误差,微调和测试之间存在泛化误差。
(b) 和 ( c)是两种基本的修剪策略。
L
D
L_D
LD?和
L
p
r
L_{pr}
Lpr?都是知识l的子集。
L
D
L_D
LD?与下游任务相关。
L
p
r
L_{pr}
Lpr?保存在经过修剪的编码器
g
(
L
p
r
)
g_{(L^{pr})}
g(Lpr)?中。
(d) 是作者提出的修剪策略。
(
L
p
r
)
D
(L^{pr})_D
(Lpr)D?为先修剪后微调获得的知识。
(
L
D
)
p
r
(L_D)^{pr}
(LD?)pr对应于蒸馏时先微调后修剪。
1. 一般的预培训和微调程序
- 在预训练过程,通过大量数据实例
(
x
p
,
y
p
)
(x^p, y^p)
(xp,yp)学习通用语言知识,用
L
L
L表示。
L
L
L包含一个与下游任务相关的子集,用
L
D
L_D
LD?表示,
L
L
L的数量远远大于
L
D
L_D
LD?的数量。
- 为了将知识
L
L
L(特别是
L
D
L_D
LD?)从预训练域转移到下游域,使用经过良好训练的编码器
g
L
g_L
gL?对下游编码器
g
L
D
g_{L_D}
gLD??进行初始化。
- 在微调过程中,下游编码器的训练是基于来自下游域的少量数据示例
(
x
d
,
y
d
)
(x^d, y^d)
(xd,yd)中保留的任务相关知识
D
D
D。
- 最后,根据测试数据对经过良好训练的下游编码器
g
L
D
g_{L_D}
gLD??进行评估。
2. 微调过程中修剪
一是在微调过程中对下游编码器
g
L
g_L
gL?进行修剪:
但是,由于优化过程中权值更新的损失仅基于下游任务域的数据示例
(
x
d
,
y
d
)
(x^d, y^d)
(xd,yd),这个数据相比于大量数据实例
(
x
p
,
y
p
)
(x^p, y^p)
(xp,yp)是很小的,所以知识
L
D
L_D
LD?很依赖于
g
L
g_L
gL?赋予的初始值,对
g
L
g_L
gL?进行修改就可能回破坏
L
D
L_D
LD?。
3. 预训练阶段修剪
另一种策略是在预训练阶段执行修剪:
生成的剪枝网络保留了知识
L
L
L的一个子集,用
L
p
r
L_{pr}
Lpr?表示。不幸的是,由于该策略忽略了下游任务信息,且
L
L
L的数量非常大,即
L
L
L远大于
L
p
r
L_{pr}
Lpr?,
L
p
r
L_{pr}
Lpr?的知识可能与我们希望保存的
L
D
L_D
LD?迥然不同。如图所示:
4. 作者提出的修剪网络
为了减少
L
D
L_D
LD?的损失,作者在修剪过程中利用知识蒸馏。使用特定任务的精细调整的语言表示模型BERT(论文链接)作为教师网络,预先训练的BERT作为学生网络。SparseBERT在蒸馏阶段进行修剪。
学生网络在预训练之后,跳过传统网络的微调的步骤,对预训练的编码器
g
L
g_L
gL?进行蒸馏,蒸馏时先微调后修剪,得到
(
L
D
)
p
r
(L_D)^{pr}
(LD?)pr。同时,教师网络基于下游数据集
(
x
d
,
y
d
)
(x^d, y^d)
(xd,yd),根据传统方法对编码器
g
L
g_L
gL?进行修剪,得到
g
L
D
g_{L_D}
gLD??。
这样,如下图所示,利用教师网络保留
L
D
L_D
LD?。通过将下游任务数据
(
x
d
,
y
d
)
(x^d, y^d)
(xd,yd)输入教师-学生框架,我们帮助学生模仿老师的行为,尽可能多地学习
L
D
L_D
LD?和
L
L
L。
|