本文介绍知识蒸馏(Knowledge Distillation)。核心思想是通过迁移知识,从而通过训练好的大模型得到更加适合推理的小模型。
文章的核心思想就是提出用soft target 来辅助hard target 一起训练,而soft target 来自于大模型的预测输出:
1、训练大模型:先用hard target ,也就是正常的标签训练大模型。 2、计算soft target :利用训练好的大模型来计算soft target 。也就是大模型“软化后”再经过softmax 的输出。 3、训练小模型,在小模型的基础上再加一个额外的soft target 的损失函数,通过alpha来调节两个损失函数的比重。 4、预测时,将训练好的小模型按常规方式(右图)使用。
1. 网络结构
定义一个teacher 网络,由两个卷积、两个池化、一个全连接层组成。
class anNet_deep(nn.Module):
def __init__(self):
super(anNet_deep,self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1,64,3,padding=1),
nn.BatchNorm2d(64),
nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2d(64,64,3,1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU())
self.conv3 = nn.Sequential(
nn.Conv2d(64,128,3,1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU())
self.conv4 = nn.Sequential(
nn.Conv2d(128,128,3,1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU())
self.pooling1 = nn.Sequential(nn.MaxPool2d(2,stride=2))
self.fc = nn.Sequential(nn.Linear(6272,10))
def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = self.pooling1(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.pooling1(x)
x = x.view(x.size()[0],-1)
x = self.fc(x)
return x
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
m.bias.data.zero_()
定义一个student 网络,由一个卷积层、池化层、全连接层构成。
class anNet(nn.Module):
def __init__(self):
super(anNet,self).__init__()
self.conv1 = nn.Conv2d(1,6,3)
self.pool1 = nn.MaxPool2d(2,1)
self.fc3 = nn.Linear(3750,10)
def forward(self,x):
x = self.conv1(x)
x = self.pool1(F.relu(x))
x = x.view(x.size()[0],-1)
x = self.fc3(x)
return x
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
m.bias.data.zero_()
2. 损失函数
知识蒸馏的关键是损失函数的设计,它包括普通的交叉熵损失和建立在soft target 基础上的损失。
hard target 包含的信息量(信息熵)很低,soft target 包含的信息量大,拥有不同类之间关系的信息。
比如,同时分类驴和马的时候,尽管某张图片是马,但是soft target 就不会像hard target 那样只有马的index 处的值为1,其余为0,而是在驴的部分也会有概率。
这样的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft 信息存在于概率中,以及标签之间的高低相似性都存在于soft target 中。
但是如果soft target 是像这样的信息[0.98 0.01 0.01] ,就意义不大了,所以需要在softmax 中增加温度参数T (这个设置在最终训练完之后的推理中是不需要的)。增加softmax 后的蒸馏损失函数:
综合损失函数:
outputs = model(inputs.float())
criterion = nn.CrossEntropyLoss()
loss1 = criterion(outputs, labels)
teacher_outputs = teach_model(inputs.float())
T = 2
alpha = 0.5
outputs_S = F.log_softmax(teacher_outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)
criterion2 = nn.KLDivLoss()
loss2 = criterion2(outputs_S,outputs_T)*T*T
loss = loss1*(1-alpha) + loss2*alpha
参考文献 深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network) 基于知识蒸馏Knowledge Distillation模型压缩pytorch实现
|