IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 知识蒸馏Knownledge Distillation -> 正文阅读

[人工智能]知识蒸馏Knownledge Distillation

知识蒸馏源自Hinton et al.于2014年发表在NIPS的一篇文章:Distilling the Knowledge in a Neural Network

1. 背景

一般情况下,我们在训练模型的时候使用了大量训练数据和计算资源来提取知识,但这不方便在工业中部署,原因有二:
(1)大模型推理速度慢
(2)对设备的资源要求高(大内存)
因此我们希望对训练好的模型进行压缩,在保证推理效果的前提下减小模型的体量,知识蒸馏(Knownledge Distillation)属于模型压缩的一种方法 [1]。

2. 知识蒸馏

名词解释:
cumbersome model:原始模型或者说大模型,但在后续的论文中一般称它为teacher model;
distilled model:蒸馏后的小模型,在后续的论文中一般称它为stududent model;
hard targets:像[1, 0, 0]这样的标签,也叫做ground-truth label;
soft targets:像[0.7, 0.2, 0.1]这样的标签;
transfer set:训练student model的数据

好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好 [3]。显然,soft target可以提供更大的信息熵,所以studetn model可以学习到更多的信息。

通俗的来讲,粗暴的使用one-hot编码把原本有帮助的类内variance和类间distance都忽略了,比如猫和狗的相似性要比猫与摩托车的相似性要多,狗的某些特征可能对识别猫也会有帮助(比如毛发),因此使用soft target可以恢复被one-hot编码丢弃的信息 [2]。

在Hinton et al. 发表的这篇论文中,作者提出了"softmax temperature"的概念,其公式为:
q i = exp ? ( z i / T ) ∑ j exp ? ( z j / T ) q_{i}=\frac{\exp (z_{i}/T)}{\sum_{j}^{}\exp (z_{j}/T)} qi?=j?exp(zj?/T)exp(zi?/T)?
Python代码:

import numpy as np
def softmax_t(x,t):
	x_exp = np.exp(x / t)
	return x_exp / np.sum(x_exp)

q i q_{i} qi?代表第 i i i类的输出概率, z i z_{i} zi? z j z_{j} zj?为softmax的输入,即上一层神经元的输出(logits),T表示temperature参数。通常情况下,我们使用的softmax函数T为1,但 T T T可以控制输出soft的程度。比如对于 z = [ 0.3 , 0.5 , 0.8 , 0.1 , 0.2 ] z=[0.3, 0.5, 0.8, 0.1, 0.2] z=[0.3,0.5,0.8,0.1,0.2],我们分别取 T = [ 0.5 , 1 , 5 , 20 ] T=[0.5, 1, 5, 20] T=[0.5,1,5,20],然后画出softmax函数的输出可以看到, T T T越小,输出的预测结果越“硬”(曲线更加曲折),T越大输出的结果越“软”(曲线更加平和)。

softmax.png

插一句题外话,为什么这里的参数是叫温度(temperature)呢?这和蒸馏(distillation)这一热力学工艺有关。在蒸馏工艺中,温度越高提取到的物质越纯越浓缩。而在知识蒸馏中,参数T越大(温度越高),teacher model产生的label越"soft",信息熵就越高,提炼的知识更具有一般性(generalization)。所以说作者将这一参数取名temperature十分有趣。

知识蒸馏示意图,图片来源:https://intellabs.github.io/distiller/knowledge_distillation.html

知识蒸馏的实现过程可以概括为:

  1. 训练teacher model;
  2. 使用高温T将teacher model中的知识蒸馏到student model(在测试时温度T设为1)。

student modeld的目标函数由一下两项的加权平均组成:

  1. distillation loss:soft targets(由teacher model产生) 和student model的soft predictions的交叉熵,这里的T使用的是和训练teacher model相同的值。(保证student model和teacher model的结果尽可能一致)
  2. student loss:hard targets 和student model的输出数据的交叉熵,但T设置为1。(保证student model的结果和实际类别标签尽可能一致)

总体的损失函数可以写作:
L ( x , W ) = α ? CE ( y , σ ( z s ; T = 1 ) ) + β ? CE ( σ ( z t ; T = τ ) , σ ( z s , T = τ ) ) \mathcal{L}(x,W)=\alpha \ast \text{CE}(y,\sigma(z_{s};T=1))+\beta \ast \text{CE}(\sigma (z_{t};T=\tau ),\sigma(z_{s},T=\tau)) L(x,W)=α?CE(y,σ(zs?;T=1))+β?CE(σ(zt?;T=τ),σ(zs?,T=τ))
其中, x x x表示输入, W W W表示student model的参数, y y y是ground-truth label, CE \text{CE} CE是交叉熵损失函数, σ \sigma σ是刚刚提到的softmax temperature激活函数, z s z_{s} zs? z t z_{t} zt?分别表示student和teacher model神经元的输出(logits), α \alpha α β \beta β表示两个权重参数 [4].

原论文指出, α \alpha α要比 β \beta β相对小一些可以取得更好的结果,因为在求梯度时soft targets被缩放了 1 / T 2 1/T^{2} 1/T2,所以第2项要乘以一个更小的权值来平衡二者在优化时的比重 [1].

换一个角度来想,这里的知识蒸馏其实是相对于对于原始交叉熵添加了一个正则项:
L ( x , W ) = CE ( y , y ^ ) + λ soft_loss ( y ′ , y ^ ) \mathcal {L}(x,W)=\text{CE}(y,\hat{y})+\lambda \text{soft\_loss}(y', \hat{y}) L(x,W)=CE(y,y^?)+λsoft_loss(y,y^?)
利用teacher model的先验知识对student model进行正则化 [5]。

本文原载于我得简书,未经授权,不得转载。


References:

[1] Distilling the Knowledge in a Neural Network.
[2] # Distilling the Knowledge in a Neural Network 论文笔记
[3] 深度神经网络模型蒸馏Distillation
[4] Knowledge Distillation
[5] 神经网络知识蒸馏 Knowledge Distillation

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-07 11:11:00  更:2022-05-07 11:13:14 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 15:54:15-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码