前言
许久没写博客,今天趁着假期最后一天,分享下今天看到的一篇关于自蒸馏的论文。
题目:Student Helping Teacher: Teacher Evolution via Self-Knowledge Distillation 地址:https://arxiv.org/abs/2110.00329 github:https://github.com/zhengli427/TESKD/
主要思路
之前的蒸馏方法基本都是teacher监督student或者多个student之间互相监督,BYOT论文中,作者将student拆分为多个block,每个block都单独接一个fc,计算celoss,相当于希望每个block都可以学到更加鲁棒的feature(当然,浅层的feature最终的分类效果肯定是相对比较差的)。
在TESKD ,也就是这篇博客的主角中,作者借鉴了BYOT以及FPN的结构,使得模型训练以一种自蒸馏的方式呈现出现,不同层级之间的feature也会进行融合并互相监督,融合之后的feature也会接avgpool以及fc,最终也是具有分类能力的,融合的这个过程进一步提升了feature的鲁棒性,也带来了更为优秀的分类结果
结构框图
下面是TESKD 自蒸馏算法的结构框图,这其中其实只有一个网络,也就是我们最终用于部署的网络,但是这里为了区分,还是将其称之为教师网络,T1~T4是区分出来的4个block,这对于ResNet等比较标准的网络来说都是比较好实现的(不同的Res stage印出来即可)。对于拿到的feature,使用下面右边的方法进行融合。具体地,
T
b
?
1
T_{b-1}
Tb?1?接上1x1卷积进行维度映射,S_b 接上2x2的上采样以及1x1 卷积(实际上是conv+bn+relu,下同)进行通道维度映射(都变成512维),然后再进行add与concat的操作,最后再接1x1 卷积进行fuse,得到
S
b
?
1
S_{b-1}
Sb?1?,感觉是右边的图,左边出来的写错了。整个操作流程类似于FPN。
弄完之后,S1~S3再接上avgpool(出来的feature用于计算feature loss)以及fc(出来的logits用于计算celoss以及kd loss)。整体结构比较清晰。
最终loss也是包含这几个部分:CELoss、KDLoss、FeatureLoss。
整个过程中,其实只有teacher model是我们用于infer的model,其他的S1~S3其实只是用于打辅助的,而且也没有使用任何的pretrain,因此可以归为self-distillation的范畴,思路确实比较有意思。
实验
相比于其他的方法,TESKD在cifar100上的优势要更加明显一些,比之前的蒸馏方法都要好
下面也给出了与BYOT的比较,毕竟二者也有很多相似的地方。一方面是因为更多的监督和级联信息,最终S1~S3以及最终的output精度都超过了BYOT;另一方面,由于特征融合,所以TESKD的浅层特征也包含了深层特征,所以精度优势相比于BYOT要更为明显。
好的蒸馏算法需要经得起ImageNet的考验,作者的实验也证实了这一点,不过提升相比于Cifar100数据集要小很多。其实蒸馏可以理解为正则化的一种思路,这个结论也是正常的,大数据集上,模型能力就相对有限了。
作者也设计了一些消融实验,看下具体是哪个部分影响最大,最终发现,蒸馏loss对于整体效果的提升还是最明显的,其他的在MFM模块中的设计也会带来一定的精度提升,但是不起主导作用。
结论
一种新的自蒸馏思路,包括之前的reviewKD等方法,其实或多或少都开始走feature merge + distillation的路子了,感觉可以从feature连接的角度,去进一步挖掘这种方法的潜力;当然,为了便利,真的要摆脱一个精度更高的教师网络吗?如果从最终的精度出发,其实还是有待商榷的,也欢迎大家讨论。
|