IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 蒸馏(Knowledge Distillation) -> 正文阅读

[人工智能]蒸馏(Knowledge Distillation)

本文介绍知识蒸馏(Knowledge Distillation)。核心思想是通过迁移知识,从而通过训练好的大模型得到更加适合推理的小模型。

文章的核心思想就是提出用soft target来辅助hard target一起训练,而soft target来自于大模型的预测输出:

1、训练大模型:先用hard target,也就是正常的标签训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的输出。
3、训练小模型,在小模型的基础上再加一个额外的soft target的损失函数,通过alpha来调节两个损失函数的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。

在这里插入图片描述

1. 网络结构

定义一个teacher网络,由两个卷积、两个池化、一个全连接层组成。

class anNet_deep(nn.Module):
    def __init__(self):
        super(anNet_deep,self).__init__()
        self.conv1 = nn.Sequential(
                nn.Conv2d(1,64,3,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv2 = nn.Sequential(
                nn.Conv2d(64,64,3,1,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv3 = nn.Sequential(
                nn.Conv2d(64,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.conv4 = nn.Sequential(
                nn.Conv2d(128,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.pooling1 = nn.Sequential(nn.MaxPool2d(2,stride=2))
        self.fc = nn.Sequential(nn.Linear(6272,10))
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pooling1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pooling1(x)
        x = x.view(x.size()[0],-1)
        x = self.fc(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

定义一个student网络,由一个卷积层、池化层、全连接层构成。

class anNet(nn.Module):
    def __init__(self):
        super(anNet,self).__init__()
        self.conv1 = nn.Conv2d(1,6,3)
        self.pool1 = nn.MaxPool2d(2,1)
        self.fc3 = nn.Linear(3750,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(F.relu(x))
        x = x.view(x.size()[0],-1)
        x = self.fc3(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

2. 损失函数

知识蒸馏的关键是损失函数的设计,它包括普通的交叉熵损失和建立在soft target基础上的损失。

hard target 包含的信息量(信息熵)很低,soft target包含的信息量大,拥有不同类之间关系的信息。

比如,同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。

这样的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft信息存在于概率中,以及标签之间的高低相似性都存在于soft target中。

但是如果soft target是像这样的信息[0.98 0.01 0.01],就意义不大了,所以需要在softmax中增加温度参数T(这个设置在最终训练完之后的推理中是不需要的)。增加softmax后的蒸馏损失函数:

在这里插入图片描述
综合损失函数:
在这里插入图片描述

# ==============================经典损失=============================== 
# 经典损失
outputs = model(inputs.float())
criterion = nn.CrossEntropyLoss()
loss1 = criterion(outputs, labels)

# ==============================蒸馏损失=============================== 
# 蒸馏损失衡量的是student网络输出与已训练好的teacher网络输出经过软化的结果之间差距
teacher_outputs = teach_model(inputs.float())

# T和alpha是两个超参数,取法对结果影响很大
T = 2       # 一般取2,10,20
alpha = 0.5 # 一般取0.5,0.9,0.95

# student网络输出软化后结果
# log_softmax与softmax没有本质的区别,只不过log_softmax会得到一个正值的loss结果。
outputs_S = F.log_softmax(teacher_outputs/T,dim=1)
# teacher网络输出软化后结果
outputs_T = F.softmax(teacher_outputs/T,dim=1)

# 蒸馏损失采用的是KL散度损失函数
criterion2 = nn.KLDivLoss()
loss2 = criterion2(outputs_S,outputs_T)*T*T

#用参数alpha综合损失结果
loss = loss1*(1-alpha) + loss2*alpha

参考文献
深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network)
基于知识蒸馏Knowledge Distillation模型压缩pytorch实现

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-01 14:30:37  更:2021-08-01 14:32:20 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/17 20:41:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码