如有错误,恳请指出。
在之前的知识蒸馏的例子中,我搭建了一个大的神经教师网络来对小的神经学生网络进行蒸馏。但是,看完那篇博客的朋友可能知道,其实效果不是很好,最后的效果甚至还不如直接训练学生网络,这成为了一时的心结。ps:这里贴上之前的两篇文章,如下所示:
1. 《知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍》 2. 《【29】知识蒸馏(knowledge distillation)测试以及利用可学习参数辅助知识蒸馏训练Student模型》
为此,我想,一个不够,那就再加一个。而且加的是卷积网络来称为另外的一个教师网络,这样又有卷积的教师网络又有普通的神经网络,可能会带来一点好的提升与性能。
下面话不多说,直接贴上jupyter notebook的测试过程:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import alexnet
import random
import tqdm
import numpy as np
from thop import profile
def SetSeed(SEED=42):
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
SetSeed()
1. 定义网络
class CNN_Teancher(nn.Module):
def __init__(self, num_classes=10):
super(CNN_Teancher, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.AdaptiveAvgPool2d(output_size=1)
)
self.classifier = nn.Sequential(
nn.Linear(256, num_classes),
nn.Dropout(p=0.5)
)
def forward(self, x):
x = self.model(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
class Teacher(nn.Module):
def __init__(self, num_classes=10):
super(Teacher, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 1200),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.Linear(1200, 1200),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.Linear(1200, num_classes)
)
def forward(self, x):
x = x.view(-1, 784)
return self.model(x)
class Student(nn.Module):
def __init__(self, num_classes=10):
super(Student, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 20),
nn.ReLU(),
nn.Linear(20, num_classes)
)
def forward(self, x):
x = x.view(-1, 784)
return self.model(x)
2. 数据集配置
epoch_size = 10
batch_size = 128
learning_rate = 1e-4
train_data = datasets.MNIST('./', train=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_data = datasets.MNIST('./', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
device = torch.device('cuda')
t_model = Teacher().to(device)
s_model = Student().to(device)
criterion = nn.CrossEntropyLoss().to(device)
服务器上没有mnist数据集,所以这里我得下载一下:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
/home/fs/anaconda3/envs/yolo/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
3. 训练与测试
def train_one_epoch(model, criterion, optimizer, dataloader):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (image, targets) in enumerate(dataloader):
image, targets = image.to(device), targets.to(device)
outputs = model(image)
loss = criterion(outputs, targets)
train_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if batch_idx % int(len(dataloader)/5) == 0:
train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
.format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
correct / total, correct, total)
print(train_info)
def validate(model, criterion, dataloader):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(dataloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
.format(test_loss/len(dataloader), correct/total, correct, total)
print(test_info)
3.1 训练教师神经网络
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)
for epoch in range(epoch_size):
print("[Epoch:{}]".format(epoch))
train_one_epoch(t_model, criterion, t_optimizer, train_loader)
validate(t_model, criterion, test_loader)
torch.save(t_model.state_dict(), "t_model.mdl")
输出结果:
[Epoch:0]
batch:0/469, loss:2.3867623805999756, acc:0.0703125 (9/128)
batch:93/469, loss:0.9396106711727508, acc:0.7176695478723404 (8635/12032)
batch:186/469, loss:0.6751111858190699, acc:0.7959558823529411 (19052/23936)
batch:279/469, loss:0.560909423657826, acc:0.8306919642857142 (29772/35840)
batch:372/469, loss:0.4910400741141859, acc:0.8519604557640751 (40676/47744)
batch:465/469, loss:0.44459033963698175, acc:0.8668689645922747 (51707/59648)
Test ==> loss:0.18145106082098394, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.20238660275936127, acc:0.9296875 (119/128)
batch:93/469, loss:0.2128744441619579, acc:0.9353390957446809 (11254/12032)
batch:186/469, loss:0.21077364432142381, acc:0.9370404411764706 (22429/23936)
batch:279/469, loss:0.20751248164368527, acc:0.937890625 (33614/35840)
batch:372/469, loss:0.20383650300170397, acc:0.938861427613941 (44825/47744)
batch:465/469, loss:0.19661776378526963, acc:0.9409368293991416 (56125/59648)
Test ==> loss:0.1174718544146494, acc:0.9613 (9613/10000)
[Epoch:2]
batch:0/469, loss:0.14401938021183014, acc:0.9296875 (119/128)
batch:93/469, loss:0.15022155472097246, acc:0.9551196808510638 (11492/12032)
batch:186/469, loss:0.15136209516761137, acc:0.9552139037433155 (22864/23936)
batch:279/469, loss:0.14719437270292213, acc:0.9559709821428571 (34262/35840)
batch:372/469, loss:0.14265749993017468, acc:0.9573349530831099 (45707/47744)
batch:465/469, loss:0.14064400804592816, acc:0.9580036212446352 (57143/59648)
Test ==> loss:0.09666164789961863, acc:0.9707 (9707/10000)
[Epoch:3]
batch:0/469, loss:0.0973980501294136, acc:0.96875 (124/128)
batch:93/469, loss:0.12336844725019121, acc:0.9634308510638298 (11592/12032)
batch:186/469, loss:0.11866759833566008, acc:0.9649481951871658 (23097/23936)
batch:279/469, loss:0.11617265337013773, acc:0.9653738839285714 (34599/35840)
batch:372/469, loss:0.114620619051498, acc:0.9653359584450402 (46089/47744)
batch:465/469, loss:0.11367744497571125, acc:0.9656484710300429 (57599/59648)
Test ==> loss:0.08241578724966207, acc:0.9736 (9736/10000)
[Epoch:4]
batch:0/469, loss:0.1517380326986313, acc:0.953125 (122/128)
batch:93/469, loss:0.09908725943495618, acc:0.9703291223404256 (11675/12032)
batch:186/469, loss:0.1005439937792041, acc:0.9687917780748663 (23189/23936)
batch:279/469, loss:0.09571810397984726, acc:0.9707868303571429 (34793/35840)
batch:372/469, loss:0.09503172322029083, acc:0.9709701742627346 (46358/47744)
batch:465/469, loss:0.0932354472793415, acc:0.9714491684549357 (57945/59648)
Test ==> loss:0.07466217042005892, acc:0.9777 (9777/10000)
[Epoch:5]
batch:0/469, loss:0.11435826867818832, acc:0.9453125 (121/128)
batch:93/469, loss:0.07882857798261846, acc:0.9756482712765957 (11739/12032)
batch:186/469, loss:0.07980489694976552, acc:0.9756433823529411 (23353/23936)
batch:279/469, loss:0.07940430943854153, acc:0.9759486607142858 (34978/35840)
batch:372/469, loss:0.07858294055824624, acc:0.9759341487935657 (46595/47744)
batch:465/469, loss:0.07859047989924578, acc:0.9759589592274678 (58214/59648)
Test ==> loss:0.06638586930812726, acc:0.9793 (9793/10000)
[Epoch:6]
batch:0/469, loss:0.041755419224500656, acc:0.9921875 (127/128)
batch:93/469, loss:0.072672658381944, acc:0.9777260638297872 (11764/12032)
batch:186/469, loss:0.07089516308337929, acc:0.9781918449197861 (23414/23936)
batch:279/469, loss:0.07038291455579124, acc:0.9781808035714286 (35058/35840)
batch:372/469, loss:0.07037656083852852, acc:0.9786988941018767 (46727/47744)
batch:465/469, loss:0.07013101020604053, acc:0.9785743025751072 (58370/59648)
Test ==> loss:0.06481085321276531, acc:0.9798 (9798/10000)
[Epoch:7]
batch:0/469, loss:0.08536697179079056, acc:0.9765625 (125/128)
batch:93/469, loss:0.05673933652368315, acc:0.9819647606382979 (11815/12032)
batch:186/469, loss:0.05650453631801003, acc:0.9825785427807486 (23519/23936)
batch:279/469, loss:0.05738637866951259, acc:0.9823381696428571 (35207/35840)
batch:372/469, loss:0.058486481106043584, acc:0.9817987600536193 (46875/47744)
batch:465/469, loss:0.06104880005676827, acc:0.9808040504291845 (58503/59648)
Test ==> loss:0.057878647279583764, acc:0.9822 (9822/10000)
[Epoch:8]
batch:0/469, loss:0.14456196129322052, acc:0.953125 (122/128)
batch:93/469, loss:0.05808326327539188, acc:0.9830452127659575 (11828/12032)
batch:186/469, loss:0.054870715339912134, acc:0.9829127673796791 (23527/23936)
batch:279/469, loss:0.05355608092754015, acc:0.9828125 (35224/35840)
batch:372/469, loss:0.052837840473683846, acc:0.9833277479892761 (46948/47744)
batch:465/469, loss:0.052799447139829016, acc:0.9832517435622318 (58649/59648)
Test ==> loss:0.05802279706054096, acc:0.9823 (9823/10000)
[Epoch:9]
batch:0/469, loss:0.02145610749721527, acc:0.9921875 (127/128)
batch:93/469, loss:0.04951975690795386, acc:0.9847074468085106 (11848/12032)
batch:186/469, loss:0.04819783547464858, acc:0.9857536764705882 (23595/23936)
batch:279/469, loss:0.04857771968735116, acc:0.9849051339285714 (35299/35840)
batch:372/469, loss:0.04838609968838919, acc:0.9847729557640751 (47017/47744)
batch:465/469, loss:0.04821777474155146, acc:0.9848779506437768 (58746/59648)
Test ==> loss:0.051948129977512206, acc:0.9838 (9838/10000)
3.2 训练教师卷积网络
ct_model = CNN_Teancher().to(device)
ct_optimizer = optim.Adam(ct_model.parameters(), lr=learning_rate)
for epoch in range(epoch_size):
print("[Epoch:{}]".format(epoch))
train_one_epoch(ct_model, criterion, ct_optimizer, train_loader)
validate(ct_model, criterion, test_loader)
torch.save(ct_model.state_dict(), "ct_model.mdl")
输出结果:
[Epoch:0]
batch:0/469, loss:2.424455404281616, acc:0.0859375 (11/128)
batch:93/469, loss:1.961557946306594, acc:0.3346908244680851 (4027/12032)
batch:186/469, loss:1.7907984537236832, acc:0.3976019385026738 (9517/23936)
batch:279/469, loss:1.6823635867663793, acc:0.4309151785714286 (15444/35840)
batch:372/469, loss:1.5959120215423626, acc:0.4554289544235925 (21744/47744)
batch:465/469, loss:1.5276136935524673, acc:0.4724718347639485 (28182/59648)
Test ==> loss:0.7926474839826173, acc:0.93 (9300/10000)
[Epoch:1]
batch:0/469, loss:1.2736332416534424, acc:0.546875 (70/128)
batch:93/469, loss:1.2054280004602798, acc:0.5482047872340425 (6596/12032)
batch:186/469, loss:1.172185863721817, acc:0.5583639705882353 (13365/23936)
batch:279/469, loss:1.157212756574154, acc:0.559375 (20048/35840)
batch:372/469, loss:1.1450653734539533, acc:0.5607406166219839 (26772/47744)
batch:465/469, loss:1.1353451458681294, acc:0.5617791040772532 (33509/59648)
Test ==> loss:0.47211516431615325, acc:0.9597 (9597/10000)
[Epoch:2]
batch:0/469, loss:1.0209004878997803, acc:0.59375 (76/128)
batch:93/469, loss:1.0488138236898057, acc:0.58203125 (7003/12032)
batch:186/469, loss:1.0366387905921528, acc:0.5831801470588235 (13959/23936)
batch:279/469, loss:1.0379441861595426, acc:0.5789620535714286 (20750/35840)
batch:372/469, loss:1.03275924745258, acc:0.5796539879356568 (27675/47744)
batch:465/469, loss:1.0288827585560059, acc:0.5789800160944206 (34535/59648)
Test ==> loss:0.332201937331429, acc:0.9632 (9632/10000)
[Epoch:3]
batch:0/469, loss:0.9517718553543091, acc:0.6171875 (79/128)
batch:93/469, loss:0.9978312501247893, acc:0.5810339095744681 (6991/12032)
batch:186/469, loss:1.0023587601070099, acc:0.5776654411764706 (13827/23936)
batch:279/469, loss:0.9987784639000893, acc:0.5787667410714286 (20743/35840)
batch:372/469, loss:0.9959872414535236, acc:0.5788790214477212 (27638/47744)
batch:465/469, loss:0.991175599492159, acc:0.5813606491416309 (34677/59648)
Test ==> loss:0.2696448630547222, acc:0.9666 (9666/10000)
[Epoch:4]
batch:0/469, loss:0.7853888869285583, acc:0.6875 (88/128)
batch:93/469, loss:0.9590603220970073, acc:0.5880152925531915 (7075/12032)
batch:186/469, loss:0.9617624403958652, acc:0.5840157085561497 (13979/23936)
batch:279/469, loss:0.9641852915287018, acc:0.5849330357142857 (20964/35840)
batch:372/469, loss:0.9647335381354467, acc:0.5839896112600537 (27882/47744)
batch:465/469, loss:0.9653625727467271, acc:0.583506571888412 (34805/59648)
Test ==> loss:0.2166517706988733, acc:0.9722 (9722/10000)
[Epoch:5]
batch:0/469, loss:0.8839341402053833, acc:0.59375 (76/128)
batch:93/469, loss:0.9494094988133045, acc:0.5852726063829787 (7042/12032)
batch:186/469, loss:0.9495056312989424, acc:0.5849348262032086 (14001/23936)
batch:279/469, loss:0.9469595074653625, acc:0.5874720982142857 (21055/35840)
batch:372/469, loss:0.9451027883281656, acc:0.5875502680965148 (28052/47744)
batch:465/469, loss:0.9445154733412255, acc:0.5877648873390557 (35059/59648)
Test ==> loss:0.1983407870689525, acc:0.9739 (9739/10000)
[Epoch:6]
batch:0/469, loss:0.9596455693244934, acc:0.59375 (76/128)
batch:93/469, loss:0.9371558269287678, acc:0.5885970744680851 (7082/12032)
batch:186/469, loss:0.933486777193406, acc:0.5869401737967914 (14049/23936)
batch:279/469, loss:0.9356063387223652, acc:0.5876953125 (21063/35840)
batch:372/469, loss:0.9321824889080774, acc:0.5887231903485255 (28108/47744)
batch:465/469, loss:0.9308745358379102, acc:0.588955203862661 (35130/59648)
Test ==> loss:0.15833796409866477, acc:0.975 (9750/10000)
[Epoch:7]
batch:0/469, loss:0.9392691254615784, acc:0.609375 (78/128)
batch:93/469, loss:0.9271161625994012, acc:0.5857712765957447 (7048/12032)
batch:186/469, loss:0.9292944392418478, acc:0.5856450534759359 (14018/23936)
batch:279/469, loss:0.9243043863347599, acc:0.5865234375 (21021/35840)
batch:372/469, loss:0.9234551397469344, acc:0.5877178284182306 (28060/47744)
batch:465/469, loss:0.9230064044950346, acc:0.5870775214592274 (35018/59648)
Test ==> loss:0.14993384857720968, acc:0.9756 (9756/10000)
[Epoch:8]
batch:0/469, loss:0.963413655757904, acc:0.5703125 (73/128)
batch:93/469, loss:0.9203153658420482, acc:0.5882646276595744 (7078/12032)
batch:186/469, loss:0.9135349348267132, acc:0.5940006684491979 (14218/23936)
batch:279/469, loss:0.9185146927833557, acc:0.5903459821428572 (21158/35840)
batch:372/469, loss:0.9155782517095673, acc:0.5907548592493298 (28205/47744)
batch:465/469, loss:0.9142030574733095, acc:0.5917884924892703 (35299/59648)
Test ==> loss:0.142572170005569, acc:0.9755 (9755/10000)
[Epoch:9]
batch:0/469, loss:0.8696589469909668, acc:0.6171875 (79/128)
batch:93/469, loss:0.9138033472477122, acc:0.5925864361702128 (7130/12032)
batch:186/469, loss:0.9100717417696581, acc:0.5940842245989305 (14220/23936)
batch:279/469, loss:0.9118563358272825, acc:0.5920479910714286 (21219/35840)
batch:372/469, loss:0.9089114867330557, acc:0.592912198391421 (28308/47744)
batch:465/469, loss:0.9042545945859263, acc:0.5938170600858369 (35420/59648)
Test ==> loss:0.13352179668749434, acc:0.9753 (9753/10000)
3.3 训练学生神经网络
s_model = Student().to(device)
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)
for epoch in range(epoch_size):
print("[Epoch:{}]".format(epoch))
train_one_epoch(s_model, criterion, s_optimizer, train_loader)
validate(s_model, criterion, test_loader)
输出结果:
[Epoch:0]
batch:0/469, loss:2.357004404067993, acc:0.1875 (24/128)
batch:93/469, loss:1.7883413776438286, acc:0.4625166223404255 (5565/12032)
batch:186/469, loss:1.4309818958216172, acc:0.5944184491978609 (14228/23936)
batch:279/469, loss:1.1971364739750112, acc:0.6667131696428571 (23895/35840)
batch:372/469, loss:1.0451311646453816, acc:0.7116915214477212 (33979/47744)
batch:465/469, loss:0.9389211163884069, acc:0.7416845493562232 (44240/59648)
Test ==> loss:0.46114580276646194, acc:0.8833 (8833/10000)
[Epoch:1]
batch:0/469, loss:0.5253818035125732, acc:0.84375 (108/128)
batch:93/469, loss:0.45607277180286165, acc:0.8744182180851063 (10521/12032)
batch:186/469, loss:0.44302085488237797, acc:0.8799715909090909 (21063/23936)
batch:279/469, loss:0.42559995853475163, acc:0.8859095982142857 (31751/35840)
batch:372/469, loss:0.4132335932261183, acc:0.8879649798927614 (42395/47744)
batch:465/469, loss:0.4041708303943212, acc:0.8905914699570815 (53122/59648)
Test ==> loss:0.34390119095391863, acc:0.9063 (9063/10000)
[Epoch:2]
batch:0/469, loss:0.27973929047584534, acc:0.9140625 (117/128)
batch:93/469, loss:0.3445089725737876, acc:0.9063331117021277 (10905/12032)
batch:186/469, loss:0.33749938792085904, acc:0.9072526737967914 (21716/23936)
batch:279/469, loss:0.3341189743684871, acc:0.9078125 (32536/35840)
batch:372/469, loss:0.33311483519646184, acc:0.9080722184986595 (43355/47744)
batch:465/469, loss:0.32733086157638114, acc:0.9092509388412017 (54235/59648)
Test ==> loss:0.29618216881269144, acc:0.9179 (9179/10000)
[Epoch:3]
batch:0/469, loss:0.3462095856666565, acc:0.8984375 (115/128)
batch:93/469, loss:0.3092869266550592, acc:0.9137300531914894 (10994/12032)
batch:186/469, loss:0.30020128891748543, acc:0.9171958556149733 (21954/23936)
batch:279/469, loss:0.2974837499537638, acc:0.9169084821428571 (32862/35840)
batch:372/469, loss:0.29631629350517774, acc:0.9173927613941019 (43800/47744)
batch:465/469, loss:0.29146103778366367, acc:0.9190584763948498 (54820/59648)
Test ==> loss:0.2715968393449542, acc:0.9234 (9234/10000)
[Epoch:4]
batch:0/469, loss:0.3141135573387146, acc:0.890625 (114/128)
batch:93/469, loss:0.2594475950649444, acc:0.9300199468085106 (11190/12032)
batch:186/469, loss:0.26877123388377105, acc:0.9263870320855615 (22174/23936)
batch:279/469, loss:0.27071290723979474, acc:0.9251674107142858 (33158/35840)
batch:372/469, loss:0.2679654088560442, acc:0.9255403820375335 (44189/47744)
batch:465/469, loss:0.2687629308119864, acc:0.9247250536480687 (55158/59648)
Test ==> loss:0.25750901823556877, acc:0.9282 (9282/10000)
[Epoch:5]
batch:0/469, loss:0.17168357968330383, acc:0.96875 (124/128)
batch:93/469, loss:0.2522462420165539, acc:0.9288563829787234 (11176/12032)
batch:186/469, loss:0.2581011697171844, acc:0.9273897058823529 (22198/23936)
batch:279/469, loss:0.2550680588132569, acc:0.9282924107142857 (33270/35840)
batch:372/469, loss:0.2549391207722173, acc:0.9280956769436998 (44311/47744)
batch:465/469, loss:0.25167093422791476, acc:0.9286145386266095 (55390/59648)
Test ==> loss:0.24905249646192865, acc:0.9306 (9306/10000)
[Epoch:6]
batch:0/469, loss:0.26825278997421265, acc:0.90625 (116/128)
batch:93/469, loss:0.2381658841796378, acc:0.9311003989361702 (11203/12032)
batch:186/469, loss:0.23873208303821278, acc:0.9309408422459893 (22283/23936)
batch:279/469, loss:0.24032189510762691, acc:0.9308035714285714 (33360/35840)
batch:372/469, loss:0.23933973307184495, acc:0.9313421581769437 (44466/47744)
batch:465/469, loss:0.23892121092124557, acc:0.931749597639485 (55577/59648)
Test ==> loss:0.23393861119505727, acc:0.9336 (9336/10000)
[Epoch:7]
batch:0/469, loss:0.3488951325416565, acc:0.921875 (118/128)
batch:93/469, loss:0.23248760267458063, acc:0.9334275265957447 (11231/12032)
batch:186/469, loss:0.22965810611286266, acc:0.9336981951871658 (22349/23936)
batch:279/469, loss:0.22631887934569803, acc:0.9351841517857142 (33517/35840)
batch:372/469, loss:0.22711910750846762, acc:0.9349028150134048 (44636/47744)
batch:465/469, loss:0.2285995162991495, acc:0.9348175965665236 (55760/59648)
Test ==> loss:0.22850063462046127, acc:0.9344 (9344/10000)
[Epoch:8]
batch:0/469, loss:0.2651296854019165, acc:0.90625 (116/128)
batch:93/469, loss:0.22513935580215555, acc:0.9361702127659575 (11264/12032)
batch:186/469, loss:0.22959564474814717, acc:0.9356199866310161 (22395/23936)
batch:279/469, loss:0.2243422551612769, acc:0.9376116071428572 (33604/35840)
batch:372/469, loss:0.2194721238742565, acc:0.9387567024128687 (44820/47744)
batch:465/469, loss:0.22032797501258583, acc:0.9381538358369099 (55959/59648)
Test ==> loss:0.2206512165220478, acc:0.9351 (9351/10000)
[Epoch:9]
batch:0/469, loss:0.2756182849407196, acc:0.921875 (118/128)
batch:93/469, loss:0.2252512621752759, acc:0.9366688829787234 (11270/12032)
batch:186/469, loss:0.22236462188436384, acc:0.936956885026738 (22427/23936)
batch:279/469, loss:0.2201754414077316, acc:0.9374441964285715 (33598/35840)
batch:372/469, loss:0.21374527243123298, acc:0.9391756032171582 (44840/47744)
batch:465/469, loss:0.21209755683789439, acc:0.9397465128755365 (56054/59648)
Test ==> loss:0.21237921054604686, acc:0.9375 (9375/10000)
现在有两个教师网络,一个是正确率为0.9838的神经网络,一个是正确率为0.9753的卷积神经网络;如果简单的训练学生网络,正确率为0.9375 ,那么现在考虑的内容是两个教师网络对学生网络进行训练。
3.4 多教师联合蒸馏训练
下面重新构建多教师网络知识蒸馏的训练函数:
def train_one_epoch_kd(s_model, t_model, ct_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):
s_model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (image, targets) in enumerate(dataloader):
image, targets = image.to(device), targets.to(device)
with torch.no_grad():
teacher_1_preds = t_model(image)
teacher_2_preds = ct_model(image)
student_preds = s_model(image)
student_loss = hard_loss(student_preds, targets)
ditillation_1_loss = soft_loss(
F.softmax(student_preds / temp, dim=1),
F.softmax(teacher_1_preds / temp, dim=1)
)
ditillation_2_loss = soft_loss(
F.softmax(student_preds / temp, dim=1),
F.softmax(teacher_2_preds / temp, dim=1)
)
loss = 0.2 * ditillation_1_loss + \
0.2 * ditillation_2_loss + \
0.6 * student_loss
train_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
_, predicted = student_preds.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if batch_idx % int(len(dataloader)/5) == 0:
train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
.format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
correct / total, correct, total)
print(train_info)
temp = 7
alpha = nn.Parameter(torch.tensor(0.35), requires_grad=True)
gama = nn.Parameter(torch.tensor(0.35), requires_grad=True)
params = [alpha, gama]
print(params)
hard_loss = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction="batchmean")
kd_model = Student().to(device)
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)
for epoch in range(epoch_size):
print("[Epoch:{}]\n".format(epoch),
"[weight | alpha:{},gama:{},st:{}]"
.format(alpha, gama, 1-alpha-gama))
train_one_epoch_kd(kd_model, t_model, ct_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
validate(kd_model, criterion, test_loader)
输出结果:
[Parameter containing:
tensor(0.2000, requires_grad=True), Parameter containing:
tensor(0.2000, requires_grad=True)]
[Epoch:0]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:1.070181131362915, acc:0.1484375 (19/128)
batch:93/469, loss:0.7189208091573512, acc:0.3718417553191489 (4474/12032)
batch:186/469, loss:0.4570929299701344, acc:0.5301637700534759 (12690/23936)
batch:279/469, loss:0.2827999550317015, acc:0.6219587053571428 (22291/35840)
batch:372/469, loss:0.16214995130137527, acc:0.6787449731903485 (32406/47744)
batch:465/469, loss:0.07729014530458164, acc:0.7165034871244635 (42738/59648)
Test ==> loss:0.4550707121438618, acc:0.883 (8830/10000)
[Epoch:1]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.21972870826721191, acc:0.7890625 (101/128)
batch:93/469, loss:-0.3207302990745991, acc:0.8866356382978723 (10668/12032)
batch:186/469, loss:-0.3327708426006338, acc:0.8868231951871658 (21227/23936)
batch:279/469, loss:-0.34525985323957037, acc:0.8899832589285714 (31897/35840)
batch:372/469, loss:-0.3539618340797782, acc:0.891504691689008 (42564/47744)
batch:465/469, loss:-0.36126882240751784, acc:0.8931732832618026 (53276/59648)
Test ==> loss:0.320476306578781, acc:0.9118 (9118/10000)
[Epoch:2]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.42595481872558594, acc:0.8984375 (115/128)
batch:93/469, loss:-0.4049103111028671, acc:0.9043384308510638 (10881/12032)
batch:186/469, loss:-0.4099318265596176, acc:0.9061246657754011 (21689/23936)
batch:279/469, loss:-0.41254837651337894, acc:0.9076450892857143 (32530/35840)
batch:372/469, loss:-0.4177393170208458, acc:0.9097687667560321 (43436/47744)
batch:465/469, loss:-0.4201628497997579, acc:0.91037419527897 (54302/59648)
Test ==> loss:0.2765346652344812, acc:0.9203 (9203/10000)
[Epoch:3]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.47607481479644775, acc:0.9375 (120/128)
batch:93/469, loss:-0.4405997947175452, acc:0.9193816489361702 (11062/12032)
batch:186/469, loss:-0.44154514938114797, acc:0.9194100935828877 (22007/23936)
batch:279/469, loss:-0.44397371081369263, acc:0.9194754464285714 (32954/35840)
batch:372/469, loss:-0.4449578410179302, acc:0.919110254691689 (43882/47744)
batch:465/469, loss:-0.44588555446765965, acc:0.9195111319742489 (54847/59648)
Test ==> loss:0.25252004533628875, acc:0.9255 (9255/10000)
[Epoch:4]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.46281740069389343, acc:0.9296875 (119/128)
batch:93/469, loss:-0.4547458785645505, acc:0.921376329787234 (11086/12032)
batch:186/469, loss:-0.46090967977110714, acc:0.9243816844919787 (22126/23936)
batch:279/469, loss:-0.4611783997288772, acc:0.9245814732142857 (33137/35840)
batch:372/469, loss:-0.46058367931810845, acc:0.9252262064343163 (44174/47744)
batch:465/469, loss:-0.4612702508596903, acc:0.9251944742489271 (55186/59648)
Test ==> loss:0.23789097774255125, acc:0.9292 (9292/10000)
[Epoch:5]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.48149824142456055, acc:0.9375 (120/128)
batch:93/469, loss:-0.47065415629681123, acc:0.9292719414893617 (11181/12032)
batch:186/469, loss:-0.46971190995711054, acc:0.9272225935828877 (22194/23936)
batch:279/469, loss:-0.4715161386345114, acc:0.928515625 (33278/35840)
batch:372/469, loss:-0.472420722326069, acc:0.9290591487935657 (44357/47744)
batch:465/469, loss:-0.47325622069733336, acc:0.9294527896995708 (55440/59648)
Test ==> loss:0.22742461035900477, acc:0.9326 (9326/10000)
[Epoch:6]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5146299600601196, acc:0.96875 (124/128)
batch:93/469, loss:-0.47914040120358165, acc:0.932845744680851 (11224/12032)
batch:186/469, loss:-0.4779144224317316, acc:0.931442179144385 (22295/23936)
batch:279/469, loss:-0.48003286900264874, acc:0.9315848214285715 (33388/35840)
batch:372/469, loss:-0.48055645328104973, acc:0.9321171246648794 (44503/47744)
batch:465/469, loss:-0.48185253974705805, acc:0.933074034334764 (55656/59648)
Test ==> loss:0.22472369784041296, acc:0.9336 (9336/10000)
[Epoch:7]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.44770073890686035, acc:0.90625 (116/128)
batch:93/469, loss:-0.4870964447234539, acc:0.9355053191489362 (11256/12032)
batch:186/469, loss:-0.48533328188294395, acc:0.9351186497326203 (22383/23936)
batch:279/469, loss:-0.48802188954183034, acc:0.9361607142857142 (33552/35840)
batch:372/469, loss:-0.4895784726251546, acc:0.9360338471849866 (44690/47744)
batch:465/469, loss:-0.4890587088759877, acc:0.9358570278969958 (55822/59648)
Test ==> loss:0.2101580535026291, acc:0.9362 (9362/10000)
[Epoch:8]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.4669145345687866, acc:0.9140625 (117/128)
batch:93/469, loss:-0.4905587165279591, acc:0.9385804521276596 (11293/12032)
batch:186/469, loss:-0.4956321018264893, acc:0.9393382352941176 (22484/23936)
batch:279/469, loss:-0.4955592581204006, acc:0.9392299107142857 (33662/35840)
batch:372/469, loss:-0.4950189735870259, acc:0.9388195375335121 (44823/47744)
batch:465/469, loss:-0.49565435664592383, acc:0.9390423819742489 (56012/59648)
Test ==> loss:0.20754675633167918, acc:0.9388 (9388/10000)
[Epoch:9]
[weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5581856966018677, acc:0.984375 (126/128)
batch:93/469, loss:-0.5014005094132525, acc:0.9423204787234043 (11338/12032)
batch:186/469, loss:-0.49853202844048566, acc:0.9412182486631016 (22529/23936)
batch:279/469, loss:-0.5006332156913621, acc:0.9414341517857143 (33741/35840)
batch:372/469, loss:-0.5006915050441394, acc:0.9413329423592494 (44943/47744)
batch:465/469, loss:-0.5015091910126895, acc:0.9409703594420601 (56127/59648)
Test ==> loss:0.20188309327711032, acc:0.9411 (9411/10000)
最后,联合训练的效果为0.9411 ,要比单独训练学生模型的效果0.9375要好。
可以看见,经过合适的调参,联合蒸馏训练的效果确实比单独训练学生网络的效果要好。不过前提就是要调好参数。调参的问题是一个血泪的问题,要是调不好,其实并不能发挥出知识蒸馏这个方法的性能,有时往往会出现没有效果的情况,就是用了效果反而还下降了。
而且,这里我本来使用了两个可学习的权重参数分配给教师卷积网络与教师神经网络,但是效果其实并不会比固定权重的训练效果要来得好。经过测试,两个教师网络分配0.2左右的权重效果是最好的。
经过这次的总结,感叹调参是个技术活,深度学习确实像是炼丹一样。
参考资料:
1. 《知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍》 2. 《【29】知识蒸馏(knowledge distillation)测试以及利用可学习参数辅助知识蒸馏训练Student模型》
|