
[Artificial Intelligence] [32] Testing joint distillation with multiple teacher networks


If you spot any mistakes, please point them out.



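In the earlier knowledge distillation example, I built a large fully connected teacher network to distill a small student network. Readers of that post may remember that the results were not good: the distilled student ended up worse than a student trained directly, which has bothered me ever since. The two earlier posts are listed here for reference:

1. "Knowledge Distillation | Principles of the knowledge distillation algorithm and other extensions"
2. "[29] Knowledge distillation test, and using learnable parameters to assist knowledge-distillation training of the Student model"

So my thought was: if one teacher is not enough, add another. The added teacher is a convolutional network, so the student learns from both a convolutional teacher and a fully connected teacher, which might bring some improvement.

Without further ado, here is the test process from the Jupyter notebook: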

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Set the random seeds so that results are reproducible
def SetSeed(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

SetSeed()

1. Defining the networks


# Define a CNN teacher network (the parameters were not tuned carefully; the layout loosely follows AlexNet)
class CNN_Teancher(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN_Teancher, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1)   # global average pooling -> (N, 256, 1, 1)
        )
        # Note: Dropout sits after the final Linear layer, so in train mode it
        # zeroes about half of the output logits. It is disabled in eval mode.
        self.classifier = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.Dropout(p=0.5)
        )
    def forward(self, x):
        x = self.model(x)
        x = x.view(x.size(0), -1)   # flatten to (N, 256)
        return self.classifier(x)
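
To see how a 28x28 MNIST image shrinks through the three convolutions above, the standard convolution output-size formula can be applied layer by layer. This is a minimal sketch: the helper name `conv_out` is made up for illustration, and it assumes dilation 1, which is the `nn.Conv2d` default.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of one spatial dimension after a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the three conv layers in CNN_Teancher
layers = [(3, 2, 1), (3, 2, 1), (3, 1, 1)]

size = 28  # MNIST images are 28x28
sizes = []
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [14, 7, 7]
```

The final 7x7 feature map with 256 channels is then reduced to 1x1 by the adaptive average pooling, which is why the classifier takes 256 input features.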


# Define the fully connected teacher network
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)   # flatten each 28x28 image to a 784-vector
        return self.model(x)


# Define the student network
class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)   # flatten each 28x28 image to a 784-vector
        return self.model(x)
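
To get a feel for the capacity gap between the three networks, their learnable parameters can be counted by hand. The sketch below is plain arithmetic: the helper names are made up for illustration, and the layer sizes are read off the class definitions above.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out

def bn_params(channels: int) -> int:
    """Learnable scale and shift of BatchNorm (running statistics are buffers, not parameters)."""
    return 2 * channels

teacher_params = linear_params(784, 1200) + linear_params(1200, 1200) + linear_params(1200, 10)
student_params = linear_params(784, 20) + linear_params(20, 10)
cnn_params = (
    conv_params(1, 64, 3) + bn_params(64)
    + conv_params(64, 192, 3) + bn_params(192)
    + conv_params(192, 256, 3) + bn_params(256)
    + linear_params(256, 10)
)

print(teacher_params)  # 2395210
print(cnn_params)      # 557642
print(student_params)  # 15910
```

By this count the fully connected teacher is roughly 150 times larger than the student, and the convolutional teacher roughly 35 times larger.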

2. Dataset configuration


# Hyperparameters
epoch_size = 10
batch_size = 128
learning_rate = 1e-4

# Download the training set
# std is kept at 0.1307 as in the original run (0.3081 is the value usually quoted for MNIST)
train_data = datasets.MNIST('./', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Download the test set
test_data = datasets.MNIST('./', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# Build the models (fall back to the CPU when no GPU is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t_model = Teacher().to(device)
s_model = Student().to(device)

# Define the loss
criterion = nn.CrossEntropyLoss().to(device)
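
`transforms.Normalize` computes `(x - mean) / std` per channel, so the choice of `std` sets the range of values the networks see. The sketch below compares the values used in this notebook (mean 0.1307, std 0.1307) with the pair usually quoted for MNIST (mean 0.1307, std 0.3081). It is plain arithmetic on the two extreme pixel values, and the helper name `normalize` is made up for illustration.

```python
def normalize(pixel: float, mean: float, std: float) -> float:
    """Per-pixel operation applied by transforms.Normalize."""
    return (pixel - mean) / std

# ToTensor() scales pixels to [0, 1], so 0.0 and 1.0 are the extremes.
notebook_range = (normalize(0.0, 0.1307, 0.1307), normalize(1.0, 0.1307, 0.1307))
common_range = (normalize(0.0, 0.1307, 0.3081), normalize(1.0, 0.1307, 0.3081))

print(notebook_range)  # roughly (-1.0, 6.65)
print(common_range)    # roughly (-0.42, 2.82)
```

All results logged in this post were produced with the notebook's values, so switching to 0.3081 would change the numbers.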

The server did not have the MNIST dataset, so it had to be downloaded first:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

/home/fs/anaconda3/envs/yolo/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

3. Training and testing


# Training loop
def train_one_epoch(model, criterion, optimizer, dataloader):

    model.train()
    train_loss = 0
    correct = 0
    total = 0
    log_every = max(1, len(dataloader) // 5)   # print about five times per epoch

    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        outputs = model(image)
        loss = criterion(outputs, targets)
        train_loss += loss.item()
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correct predictions
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % log_every == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)

# Evaluation loop
def validate(model, criterion, dataloader):

    model.eval()   # disables Dropout and uses BatchNorm running statistics
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
          .format(test_loss/len(dataloader), correct/total, correct, total)
    print(test_info)

3.1 Training the fully connected teacher network

# Optimizer for the teacher network
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)
# Train the teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(t_model, criterion, t_optimizer, train_loader)
    validate(t_model, criterion, test_loader)
# After training, save the teacher's parameters so they can be reloaded later if needed
torch.save(t_model.state_dict(), "t_model.mdl")

Output:

[Epoch:0]
batch:0/469, loss:2.3867623805999756, acc:0.0703125 (9/128)
batch:93/469, loss:0.9396106711727508, acc:0.7176695478723404 (8635/12032)
batch:186/469, loss:0.6751111858190699, acc:0.7959558823529411 (19052/23936)
batch:279/469, loss:0.560909423657826, acc:0.8306919642857142 (29772/35840)
batch:372/469, loss:0.4910400741141859, acc:0.8519604557640751 (40676/47744)
batch:465/469, loss:0.44459033963698175, acc:0.8668689645922747 (51707/59648)
Test ==> loss:0.18145106082098394, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.20238660275936127, acc:0.9296875 (119/128)
batch:93/469, loss:0.2128744441619579, acc:0.9353390957446809 (11254/12032)
batch:186/469, loss:0.21077364432142381, acc:0.9370404411764706 (22429/23936)
batch:279/469, loss:0.20751248164368527, acc:0.937890625 (33614/35840)
batch:372/469, loss:0.20383650300170397, acc:0.938861427613941 (44825/47744)
batch:465/469, loss:0.19661776378526963, acc:0.9409368293991416 (56125/59648)
Test ==> loss:0.1174718544146494, acc:0.9613 (9613/10000)
[Epoch:2]
batch:0/469, loss:0.14401938021183014, acc:0.9296875 (119/128)
batch:93/469, loss:0.15022155472097246, acc:0.9551196808510638 (11492/12032)
batch:186/469, loss:0.15136209516761137, acc:0.9552139037433155 (22864/23936)
batch:279/469, loss:0.14719437270292213, acc:0.9559709821428571 (34262/35840)
batch:372/469, loss:0.14265749993017468, acc:0.9573349530831099 (45707/47744)
batch:465/469, loss:0.14064400804592816, acc:0.9580036212446352 (57143/59648)
Test ==> loss:0.09666164789961863, acc:0.9707 (9707/10000)
[Epoch:3]
batch:0/469, loss:0.0973980501294136, acc:0.96875 (124/128)
batch:93/469, loss:0.12336844725019121, acc:0.9634308510638298 (11592/12032)
batch:186/469, loss:0.11866759833566008, acc:0.9649481951871658 (23097/23936)
batch:279/469, loss:0.11617265337013773, acc:0.9653738839285714 (34599/35840)
batch:372/469, loss:0.114620619051498, acc:0.9653359584450402 (46089/47744)
batch:465/469, loss:0.11367744497571125, acc:0.9656484710300429 (57599/59648)
Test ==> loss:0.08241578724966207, acc:0.9736 (9736/10000)
[Epoch:4]
batch:0/469, loss:0.1517380326986313, acc:0.953125 (122/128)
batch:93/469, loss:0.09908725943495618, acc:0.9703291223404256 (11675/12032)
batch:186/469, loss:0.1005439937792041, acc:0.9687917780748663 (23189/23936)
batch:279/469, loss:0.09571810397984726, acc:0.9707868303571429 (34793/35840)
batch:372/469, loss:0.09503172322029083, acc:0.9709701742627346 (46358/47744)
batch:465/469, loss:0.0932354472793415, acc:0.9714491684549357 (57945/59648)
Test ==> loss:0.07466217042005892, acc:0.9777 (9777/10000)
[Epoch:5]
batch:0/469, loss:0.11435826867818832, acc:0.9453125 (121/128)
batch:93/469, loss:0.07882857798261846, acc:0.9756482712765957 (11739/12032)
batch:186/469, loss:0.07980489694976552, acc:0.9756433823529411 (23353/23936)
batch:279/469, loss:0.07940430943854153, acc:0.9759486607142858 (34978/35840)
batch:372/469, loss:0.07858294055824624, acc:0.9759341487935657 (46595/47744)
batch:465/469, loss:0.07859047989924578, acc:0.9759589592274678 (58214/59648)
Test ==> loss:0.06638586930812726, acc:0.9793 (9793/10000)
[Epoch:6]
batch:0/469, loss:0.041755419224500656, acc:0.9921875 (127/128)
batch:93/469, loss:0.072672658381944, acc:0.9777260638297872 (11764/12032)
batch:186/469, loss:0.07089516308337929, acc:0.9781918449197861 (23414/23936)
batch:279/469, loss:0.07038291455579124, acc:0.9781808035714286 (35058/35840)
batch:372/469, loss:0.07037656083852852, acc:0.9786988941018767 (46727/47744)
batch:465/469, loss:0.07013101020604053, acc:0.9785743025751072 (58370/59648)
Test ==> loss:0.06481085321276531, acc:0.9798 (9798/10000)
[Epoch:7]
batch:0/469, loss:0.08536697179079056, acc:0.9765625 (125/128)
batch:93/469, loss:0.05673933652368315, acc:0.9819647606382979 (11815/12032)
batch:186/469, loss:0.05650453631801003, acc:0.9825785427807486 (23519/23936)
batch:279/469, loss:0.05738637866951259, acc:0.9823381696428571 (35207/35840)
batch:372/469, loss:0.058486481106043584, acc:0.9817987600536193 (46875/47744)
batch:465/469, loss:0.06104880005676827, acc:0.9808040504291845 (58503/59648)
Test ==> loss:0.057878647279583764, acc:0.9822 (9822/10000)
[Epoch:8]
batch:0/469, loss:0.14456196129322052, acc:0.953125 (122/128)
batch:93/469, loss:0.05808326327539188, acc:0.9830452127659575 (11828/12032)
batch:186/469, loss:0.054870715339912134, acc:0.9829127673796791 (23527/23936)
batch:279/469, loss:0.05355608092754015, acc:0.9828125 (35224/35840)
batch:372/469, loss:0.052837840473683846, acc:0.9833277479892761 (46948/47744)
batch:465/469, loss:0.052799447139829016, acc:0.9832517435622318 (58649/59648)
Test ==> loss:0.05802279706054096, acc:0.9823 (9823/10000)
[Epoch:9]
batch:0/469, loss:0.02145610749721527, acc:0.9921875 (127/128)
batch:93/469, loss:0.04951975690795386, acc:0.9847074468085106 (11848/12032)
batch:186/469, loss:0.04819783547464858, acc:0.9857536764705882 (23595/23936)
batch:279/469, loss:0.04857771968735116, acc:0.9849051339285714 (35299/35840)
batch:372/469, loss:0.04838609968838919, acc:0.9847729557640751 (47017/47744)
batch:465/469, loss:0.04821777474155146, acc:0.9848779506437768 (58746/59648)
Test ==> loss:0.051948129977512206, acc:0.9838 (9838/10000)
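
The batch indices and sample totals in the log above follow directly from the dataset size and the batch size. The short sketch below reproduces them, assuming the standard 60,000-image MNIST training split and the logging interval of one fifth of the batch count, rounded down.

```python
import math

num_train = 60000   # size of the MNIST training split
batch_size = 128

num_batches = math.ceil(num_train / batch_size)   # the last batch is smaller
log_every = num_batches // 5                      # logging interval used in train_one_epoch
logged_batches = [i for i in range(num_batches) if i % log_every == 0]

# Samples seen once batch index 465 has been processed (all of these batches are full)
seen_at_465 = (465 + 1) * batch_size
last_batch_size = num_train - (num_batches - 1) * batch_size

print(num_batches)      # 469
print(logged_batches)   # [0, 93, 186, 279, 372, 465]
print(seen_at_465)      # 59648
print(last_batch_size)  # 96
```

This is why every epoch prints six training lines and why the last of them reports 59648 samples rather than 60000.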

3.2 Training the convolutional teacher network

# Build the convolutional teacher network
ct_model = CNN_Teancher().to(device)
# Optimizer for the CNN teacher network
ct_optimizer = optim.Adam(ct_model.parameters(), lr=learning_rate)
# Train the CNN teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(ct_model, criterion, ct_optimizer, train_loader)
    validate(ct_model, criterion, test_loader)
# After training, save the CNN teacher's parameters so they can be reloaded later if needed
torch.save(ct_model.state_dict(), "ct_model.mdl")

Output:

[Epoch:0]
batch:0/469, loss:2.424455404281616, acc:0.0859375 (11/128)
batch:93/469, loss:1.961557946306594, acc:0.3346908244680851 (4027/12032)
batch:186/469, loss:1.7907984537236832, acc:0.3976019385026738 (9517/23936)
batch:279/469, loss:1.6823635867663793, acc:0.4309151785714286 (15444/35840)
batch:372/469, loss:1.5959120215423626, acc:0.4554289544235925 (21744/47744)
batch:465/469, loss:1.5276136935524673, acc:0.4724718347639485 (28182/59648)
Test ==> loss:0.7926474839826173, acc:0.93 (9300/10000)
[Epoch:1]
batch:0/469, loss:1.2736332416534424, acc:0.546875 (70/128)
batch:93/469, loss:1.2054280004602798, acc:0.5482047872340425 (6596/12032)
batch:186/469, loss:1.172185863721817, acc:0.5583639705882353 (13365/23936)
batch:279/469, loss:1.157212756574154, acc:0.559375 (20048/35840)
batch:372/469, loss:1.1450653734539533, acc:0.5607406166219839 (26772/47744)
batch:465/469, loss:1.1353451458681294, acc:0.5617791040772532 (33509/59648)
Test ==> loss:0.47211516431615325, acc:0.9597 (9597/10000)
[Epoch:2]
batch:0/469, loss:1.0209004878997803, acc:0.59375 (76/128)
batch:93/469, loss:1.0488138236898057, acc:0.58203125 (7003/12032)
batch:186/469, loss:1.0366387905921528, acc:0.5831801470588235 (13959/23936)
batch:279/469, loss:1.0379441861595426, acc:0.5789620535714286 (20750/35840)
batch:372/469, loss:1.03275924745258, acc:0.5796539879356568 (27675/47744)
batch:465/469, loss:1.0288827585560059, acc:0.5789800160944206 (34535/59648)
Test ==> loss:0.332201937331429, acc:0.9632 (9632/10000)
[Epoch:3]
batch:0/469, loss:0.9517718553543091, acc:0.6171875 (79/128)
batch:93/469, loss:0.9978312501247893, acc:0.5810339095744681 (6991/12032)
batch:186/469, loss:1.0023587601070099, acc:0.5776654411764706 (13827/23936)
batch:279/469, loss:0.9987784639000893, acc:0.5787667410714286 (20743/35840)
batch:372/469, loss:0.9959872414535236, acc:0.5788790214477212 (27638/47744)
batch:465/469, loss:0.991175599492159, acc:0.5813606491416309 (34677/59648)
Test ==> loss:0.2696448630547222, acc:0.9666 (9666/10000)
[Epoch:4]
batch:0/469, loss:0.7853888869285583, acc:0.6875 (88/128)
batch:93/469, loss:0.9590603220970073, acc:0.5880152925531915 (7075/12032)
batch:186/469, loss:0.9617624403958652, acc:0.5840157085561497 (13979/23936)
batch:279/469, loss:0.9641852915287018, acc:0.5849330357142857 (20964/35840)
batch:372/469, loss:0.9647335381354467, acc:0.5839896112600537 (27882/47744)
batch:465/469, loss:0.9653625727467271, acc:0.583506571888412 (34805/59648)
Test ==> loss:0.2166517706988733, acc:0.9722 (9722/10000)
[Epoch:5]
batch:0/469, loss:0.8839341402053833, acc:0.59375 (76/128)
batch:93/469, loss:0.9494094988133045, acc:0.5852726063829787 (7042/12032)
batch:186/469, loss:0.9495056312989424, acc:0.5849348262032086 (14001/23936)
batch:279/469, loss:0.9469595074653625, acc:0.5874720982142857 (21055/35840)
batch:372/469, loss:0.9451027883281656, acc:0.5875502680965148 (28052/47744)
batch:465/469, loss:0.9445154733412255, acc:0.5877648873390557 (35059/59648)
Test ==> loss:0.1983407870689525, acc:0.9739 (9739/10000)
[Epoch:6]
batch:0/469, loss:0.9596455693244934, acc:0.59375 (76/128)
batch:93/469, loss:0.9371558269287678, acc:0.5885970744680851 (7082/12032)
batch:186/469, loss:0.933486777193406, acc:0.5869401737967914 (14049/23936)
batch:279/469, loss:0.9356063387223652, acc:0.5876953125 (21063/35840)
batch:372/469, loss:0.9321824889080774, acc:0.5887231903485255 (28108/47744)
batch:465/469, loss:0.9308745358379102, acc:0.588955203862661 (35130/59648)
Test ==> loss:0.15833796409866477, acc:0.975 (9750/10000)
[Epoch:7]
batch:0/469, loss:0.9392691254615784, acc:0.609375 (78/128)
batch:93/469, loss:0.9271161625994012, acc:0.5857712765957447 (7048/12032)
batch:186/469, loss:0.9292944392418478, acc:0.5856450534759359 (14018/23936)
batch:279/469, loss:0.9243043863347599, acc:0.5865234375 (21021/35840)
batch:372/469, loss:0.9234551397469344, acc:0.5877178284182306 (28060/47744)
batch:465/469, loss:0.9230064044950346, acc:0.5870775214592274 (35018/59648)
Test ==> loss:0.14993384857720968, acc:0.9756 (9756/10000)
[Epoch:8]
batch:0/469, loss:0.963413655757904, acc:0.5703125 (73/128)
batch:93/469, loss:0.9203153658420482, acc:0.5882646276595744 (7078/12032)
batch:186/469, loss:0.9135349348267132, acc:0.5940006684491979 (14218/23936)
batch:279/469, loss:0.9185146927833557, acc:0.5903459821428572 (21158/35840)
batch:372/469, loss:0.9155782517095673, acc:0.5907548592493298 (28205/47744)
batch:465/469, loss:0.9142030574733095, acc:0.5917884924892703 (35299/59648)
Test ==> loss:0.142572170005569, acc:0.9755 (9755/10000)
[Epoch:9]
batch:0/469, loss:0.8696589469909668, acc:0.6171875 (79/128)
batch:93/469, loss:0.9138033472477122, acc:0.5925864361702128 (7130/12032)
batch:186/469, loss:0.9100717417696581, acc:0.5940842245989305 (14220/23936)
batch:279/469, loss:0.9118563358272825, acc:0.5920479910714286 (21219/35840)
batch:372/469, loss:0.9089114867330557, acc:0.592912198391421 (28308/47744)
batch:465/469, loss:0.9042545945859263, acc:0.5938170600858369 (35420/59648)
Test ==> loss:0.13352179668749434, acc:0.9753 (9753/10000)
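
In the log above, training accuracy plateaus near 0.59 while test accuracy reaches about 0.975. One plausible cause, offered here as an assumption rather than something the original post states, is the `nn.Dropout(p=0.5)` placed after the final `Linear` layer of `CNN_Teancher`: in train mode it zeroes about half of the output logits, while `validate()` switches to eval mode and disables it. The toy simulation below uses hypothetical, perfectly separated logits and assumes that ties are resolved in favour of the lowest class index. It is not a measurement of the real model.

```python
import random

def dropout(logits, p, rng):
    """Inverted dropout: zero each value with probability p, scale the rest by 1 / (1 - p)."""
    return [0.0 if rng.random() < p else x / (1.0 - p) for x in logits]

def simulate_train_accuracy(num_samples=20000, num_classes=10, p=0.5, seed=0):
    rng = random.Random(seed)
    correct = 0
    for _ in range(num_samples):
        target = rng.randrange(num_classes)
        # Hypothetical logits of a model that is always right before dropout
        logits = [4.0 if c == target else -4.0 for c in range(num_classes)]
        dropped = dropout(logits, p, rng)
        # max() returns the first maximal index, i.e. ties go to the lowest class
        predicted = max(range(num_classes), key=lambda c: dropped[c])
        correct += predicted == target
    return correct / num_samples

train_mode_accuracy = simulate_train_accuracy()
print(train_mode_accuracy)  # close to 0.6 under these assumptions
```

Under these assumptions even a perfect classifier scores only about 0.6 in train mode, which is in the same range as the logged training accuracy. The test accuracy is unaffected because dropout is inactive during evaluation.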

3.3 Training the student network

# Build the student network
s_model = Student().to(device)
# Optimizer for the student network
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)
# Train the student network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(s_model, criterion, s_optimizer, train_loader)
    validate(s_model, criterion, test_loader)

Output:

[Epoch:0]
batch:0/469, loss:2.357004404067993, acc:0.1875 (24/128)
batch:93/469, loss:1.7883413776438286, acc:0.4625166223404255 (5565/12032)
batch:186/469, loss:1.4309818958216172, acc:0.5944184491978609 (14228/23936)
batch:279/469, loss:1.1971364739750112, acc:0.6667131696428571 (23895/35840)
batch:372/469, loss:1.0451311646453816, acc:0.7116915214477212 (33979/47744)
batch:465/469, loss:0.9389211163884069, acc:0.7416845493562232 (44240/59648)
Test ==> loss:0.46114580276646194, acc:0.8833 (8833/10000)
[Epoch:1]
batch:0/469, loss:0.5253818035125732, acc:0.84375 (108/128)
batch:93/469, loss:0.45607277180286165, acc:0.8744182180851063 (10521/12032)
batch:186/469, loss:0.44302085488237797, acc:0.8799715909090909 (21063/23936)
batch:279/469, loss:0.42559995853475163, acc:0.8859095982142857 (31751/35840)
batch:372/469, loss:0.4132335932261183, acc:0.8879649798927614 (42395/47744)
batch:465/469, loss:0.4041708303943212, acc:0.8905914699570815 (53122/59648)
Test ==> loss:0.34390119095391863, acc:0.9063 (9063/10000)
[Epoch:2]
batch:0/469, loss:0.27973929047584534, acc:0.9140625 (117/128)
batch:93/469, loss:0.3445089725737876, acc:0.9063331117021277 (10905/12032)
batch:186/469, loss:0.33749938792085904, acc:0.9072526737967914 (21716/23936)
batch:279/469, loss:0.3341189743684871, acc:0.9078125 (32536/35840)
batch:372/469, loss:0.33311483519646184, acc:0.9080722184986595 (43355/47744)
batch:465/469, loss:0.32733086157638114, acc:0.9092509388412017 (54235/59648)
Test ==> loss:0.29618216881269144, acc:0.9179 (9179/10000)
[Epoch:3]
batch:0/469, loss:0.3462095856666565, acc:0.8984375 (115/128)
batch:93/469, loss:0.3092869266550592, acc:0.9137300531914894 (10994/12032)
batch:186/469, loss:0.30020128891748543, acc:0.9171958556149733 (21954/23936)
batch:279/469, loss:0.2974837499537638, acc:0.9169084821428571 (32862/35840)
batch:372/469, loss:0.29631629350517774, acc:0.9173927613941019 (43800/47744)
batch:465/469, loss:0.29146103778366367, acc:0.9190584763948498 (54820/59648)
Test ==> loss:0.2715968393449542, acc:0.9234 (9234/10000)
[Epoch:4]
batch:0/469, loss:0.3141135573387146, acc:0.890625 (114/128)
batch:93/469, loss:0.2594475950649444, acc:0.9300199468085106 (11190/12032)
batch:186/469, loss:0.26877123388377105, acc:0.9263870320855615 (22174/23936)
batch:279/469, loss:0.27071290723979474, acc:0.9251674107142858 (33158/35840)
batch:372/469, loss:0.2679654088560442, acc:0.9255403820375335 (44189/47744)
batch:465/469, loss:0.2687629308119864, acc:0.9247250536480687 (55158/59648)
Test ==> loss:0.25750901823556877, acc:0.9282 (9282/10000)
[Epoch:5]
batch:0/469, loss:0.17168357968330383, acc:0.96875 (124/128)
batch:93/469, loss:0.2522462420165539, acc:0.9288563829787234 (11176/12032)
batch:186/469, loss:0.2581011697171844, acc:0.9273897058823529 (22198/23936)
batch:279/469, loss:0.2550680588132569, acc:0.9282924107142857 (33270/35840)
batch:372/469, loss:0.2549391207722173, acc:0.9280956769436998 (44311/47744)
batch:465/469, loss:0.25167093422791476, acc:0.9286145386266095 (55390/59648)
Test ==> loss:0.24905249646192865, acc:0.9306 (9306/10000)
[Epoch:6]
batch:0/469, loss:0.26825278997421265, acc:0.90625 (116/128)
batch:93/469, loss:0.2381658841796378, acc:0.9311003989361702 (11203/12032)
batch:186/469, loss:0.23873208303821278, acc:0.9309408422459893 (22283/23936)
batch:279/469, loss:0.24032189510762691, acc:0.9308035714285714 (33360/35840)
batch:372/469, loss:0.23933973307184495, acc:0.9313421581769437 (44466/47744)
batch:465/469, loss:0.23892121092124557, acc:0.931749597639485 (55577/59648)
Test ==> loss:0.23393861119505727, acc:0.9336 (9336/10000)
[Epoch:7]
batch:0/469, loss:0.3488951325416565, acc:0.921875 (118/128)
batch:93/469, loss:0.23248760267458063, acc:0.9334275265957447 (11231/12032)
batch:186/469, loss:0.22965810611286266, acc:0.9336981951871658 (22349/23936)
batch:279/469, loss:0.22631887934569803, acc:0.9351841517857142 (33517/35840)
batch:372/469, loss:0.22711910750846762, acc:0.9349028150134048 (44636/47744)
batch:465/469, loss:0.2285995162991495, acc:0.9348175965665236 (55760/59648)
Test ==> loss:0.22850063462046127, acc:0.9344 (9344/10000)
[Epoch:8]
batch:0/469, loss:0.2651296854019165, acc:0.90625 (116/128)
batch:93/469, loss:0.22513935580215555, acc:0.9361702127659575 (11264/12032)
batch:186/469, loss:0.22959564474814717, acc:0.9356199866310161 (22395/23936)
batch:279/469, loss:0.2243422551612769, acc:0.9376116071428572 (33604/35840)
batch:372/469, loss:0.2194721238742565, acc:0.9387567024128687 (44820/47744)
batch:465/469, loss:0.22032797501258583, acc:0.9381538358369099 (55959/59648)
Test ==> loss:0.2206512165220478, acc:0.9351 (9351/10000)
[Epoch:9]
batch:0/469, loss:0.2756182849407196, acc:0.921875 (118/128)
batch:93/469, loss:0.2252512621752759, acc:0.9366688829787234 (11270/12032)
batch:186/469, loss:0.22236462188436384, acc:0.936956885026738 (22427/23936)
batch:279/469, loss:0.2201754414077316, acc:0.9374441964285715 (33598/35840)
batch:372/469, loss:0.21374527243123298, acc:0.9391756032171582 (44840/47744)
batch:465/469, loss:0.21209755683789439, acc:0.9397465128755365 (56054/59648)
Test ==> loss:0.21237921054604686, acc:0.9375 (9375/10000)

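We now have two teacher networks: a fully connected network with a test accuracy of 0.9838 and a convolutional network with a test accuracy of 0.9753. Training the student on its own gives 0.9375. The next step is to train the student with both teachers at once.

3.4 Joint distillation with multiple teachers

The training function is rebuilt below for multi-teacher knowledge distillation: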

# Training loop for multi-teacher distillation
def train_one_epoch_kd(s_model, t_model, ct_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):

    s_model.train()
    # The teachers only provide targets; keep Dropout and BatchNorm in eval mode
    t_model.eval()
    ct_model.eval()
    train_loss = 0
    correct = 0
    total = 0
    log_every = max(1, len(dataloader) // 5)   # print about five times per epoch

    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_1_preds = t_model(image)
            teacher_2_preds = ct_model(image)
        # Student predictions
        student_preds = s_model(image)

        # Loss against the ground-truth labels: hard loss
        student_loss = hard_loss(student_preds, targets)
        # Loss against the teacher outputs: soft loss.
        # nn.KLDivLoss expects log-probabilities as input and probabilities as target.
        # NOTE: the original run passed F.softmax for the student as well, which is
        # why the training loss in the log below is negative.
        ditillation_1_loss = soft_loss(
            F.log_softmax(student_preds / temp, dim=1),
            F.softmax(teacher_1_preds / temp, dim=1)
        )
        ditillation_2_loss = soft_loss(
            F.log_softmax(student_preds / temp, dim=1),
            F.softmax(teacher_2_preds / temp, dim=1)
        )

        # Total loss: weighted sum of the hard loss and the two soft losses.
        # Uncomment the lines below to use learnable weights instead of fixed ones.
#         loss = alpha * ditillation_1_loss + \
#                 gama * ditillation_2_loss + \
#                 (1 - alpha - gama) * student_loss
        loss = 0.2 * ditillation_1_loss + \
                0.2 * ditillation_2_loss + \
                0.6 * student_loss
        train_loss += loss.item()

        # Backward pass and parameter update
        optimizer.zero_grad()
#         parms_optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         parms_optimizer.step()

        # Count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % log_every == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
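
The soft loss in the function above compares distributions obtained by dividing the logits by a temperature before the softmax. The sketch below shows the effect of that division on a made-up set of teacher logits. The values are hypothetical; the temperature of 7 is the one this notebook uses.

```python
import math

def softmax(logits, temp=1.0):
    """Temperature-scaled softmax, computed in a numerically stable way."""
    scaled = [x / temp for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [8.0, 2.0, 0.0]   # hypothetical values for three classes

sharp = softmax(teacher_logits, temp=1.0)   # almost one-hot
soft = softmax(teacher_logits, temp=7.0)    # much flatter

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

At temperature 1 nearly all probability mass sits on the first class. At temperature 7 the other classes receive a visible share, and that relative ranking of the wrong classes is the extra information the student is meant to pick up from the teachers.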
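  • Joint distillation training
# Hyperparameters for knowledge distillation
temp = 7      # distillation temperature
# Initial weights are 0.2 each, matching the parameters printed in the output below.
# They only affect training if the learnable-weight lines in train_one_epoch_kd are uncommented.
alpha = nn.Parameter(torch.tensor(0.2), requires_grad=True)   # weight of teacher 1
gama  = nn.Parameter(torch.tensor(0.2), requires_grad=True)   # weight of teacher 2
params = [alpha, gama]     # two learnable parameters
# params = [alpha, ]           # a single learnable parameter
print(params)

# Loss functions for the distilled student
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # no softmax inside, so the temperature can be applied by hand

# Build the student model to be distilled
kd_model = Student().to(device)
# Optimizers for the distilled student and for the loss weights
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)
# Train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]\n".format(epoch), 
          "[weight | alpha:{},gama:{},st:{}]"
              .format(alpha, gama, 1-alpha-gama))
    train_one_epoch_kd(kd_model, t_model, ct_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)

Output: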

[Parameter containing:
tensor(0.2000, requires_grad=True), Parameter containing:
tensor(0.2000, requires_grad=True)]
[Epoch:0]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:1.070181131362915, acc:0.1484375 (19/128)
batch:93/469, loss:0.7189208091573512, acc:0.3718417553191489 (4474/12032)
batch:186/469, loss:0.4570929299701344, acc:0.5301637700534759 (12690/23936)
batch:279/469, loss:0.2827999550317015, acc:0.6219587053571428 (22291/35840)
batch:372/469, loss:0.16214995130137527, acc:0.6787449731903485 (32406/47744)
batch:465/469, loss:0.07729014530458164, acc:0.7165034871244635 (42738/59648)
Test ==> loss:0.4550707121438618, acc:0.883 (8830/10000)
[Epoch:1]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.21972870826721191, acc:0.7890625 (101/128)
batch:93/469, loss:-0.3207302990745991, acc:0.8866356382978723 (10668/12032)
batch:186/469, loss:-0.3327708426006338, acc:0.8868231951871658 (21227/23936)
batch:279/469, loss:-0.34525985323957037, acc:0.8899832589285714 (31897/35840)
batch:372/469, loss:-0.3539618340797782, acc:0.891504691689008 (42564/47744)
batch:465/469, loss:-0.36126882240751784, acc:0.8931732832618026 (53276/59648)
Test ==> loss:0.320476306578781, acc:0.9118 (9118/10000)
[Epoch:2]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.42595481872558594, acc:0.8984375 (115/128)
batch:93/469, loss:-0.4049103111028671, acc:0.9043384308510638 (10881/12032)
batch:186/469, loss:-0.4099318265596176, acc:0.9061246657754011 (21689/23936)
batch:279/469, loss:-0.41254837651337894, acc:0.9076450892857143 (32530/35840)
batch:372/469, loss:-0.4177393170208458, acc:0.9097687667560321 (43436/47744)
batch:465/469, loss:-0.4201628497997579, acc:0.91037419527897 (54302/59648)
Test ==> loss:0.2765346652344812, acc:0.9203 (9203/10000)
[Epoch:3]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.47607481479644775, acc:0.9375 (120/128)
batch:93/469, loss:-0.4405997947175452, acc:0.9193816489361702 (11062/12032)
batch:186/469, loss:-0.44154514938114797, acc:0.9194100935828877 (22007/23936)
batch:279/469, loss:-0.44397371081369263, acc:0.9194754464285714 (32954/35840)
batch:372/469, loss:-0.4449578410179302, acc:0.919110254691689 (43882/47744)
batch:465/469, loss:-0.44588555446765965, acc:0.9195111319742489 (54847/59648)
Test ==> loss:0.25252004533628875, acc:0.9255 (9255/10000)
[Epoch:4]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.46281740069389343, acc:0.9296875 (119/128)
batch:93/469, loss:-0.4547458785645505, acc:0.921376329787234 (11086/12032)
batch:186/469, loss:-0.46090967977110714, acc:0.9243816844919787 (22126/23936)
batch:279/469, loss:-0.4611783997288772, acc:0.9245814732142857 (33137/35840)
batch:372/469, loss:-0.46058367931810845, acc:0.9252262064343163 (44174/47744)
batch:465/469, loss:-0.4612702508596903, acc:0.9251944742489271 (55186/59648)
Test ==> loss:0.23789097774255125, acc:0.9292 (9292/10000)
[Epoch:5]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.48149824142456055, acc:0.9375 (120/128)
batch:93/469, loss:-0.47065415629681123, acc:0.9292719414893617 (11181/12032)
batch:186/469, loss:-0.46971190995711054, acc:0.9272225935828877 (22194/23936)
batch:279/469, loss:-0.4715161386345114, acc:0.928515625 (33278/35840)
batch:372/469, loss:-0.472420722326069, acc:0.9290591487935657 (44357/47744)
batch:465/469, loss:-0.47325622069733336, acc:0.9294527896995708 (55440/59648)
Test ==> loss:0.22742461035900477, acc:0.9326 (9326/10000)
[Epoch:6]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5146299600601196, acc:0.96875 (124/128)
batch:93/469, loss:-0.47914040120358165, acc:0.932845744680851 (11224/12032)
batch:186/469, loss:-0.4779144224317316, acc:0.931442179144385 (22295/23936)
batch:279/469, loss:-0.48003286900264874, acc:0.9315848214285715 (33388/35840)
batch:372/469, loss:-0.48055645328104973, acc:0.9321171246648794 (44503/47744)
batch:465/469, loss:-0.48185253974705805, acc:0.933074034334764 (55656/59648)
Test ==> loss:0.22472369784041296, acc:0.9336 (9336/10000)
[Epoch:7]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.44770073890686035, acc:0.90625 (116/128)
batch:93/469, loss:-0.4870964447234539, acc:0.9355053191489362 (11256/12032)
batch:186/469, loss:-0.48533328188294395, acc:0.9351186497326203 (22383/23936)
batch:279/469, loss:-0.48802188954183034, acc:0.9361607142857142 (33552/35840)
batch:372/469, loss:-0.4895784726251546, acc:0.9360338471849866 (44690/47744)
batch:465/469, loss:-0.4890587088759877, acc:0.9358570278969958 (55822/59648)
Test ==> loss:0.2101580535026291, acc:0.9362 (9362/10000)
[Epoch:8]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.4669145345687866, acc:0.9140625 (117/128)
batch:93/469, loss:-0.4905587165279591, acc:0.9385804521276596 (11293/12032)
batch:186/469, loss:-0.4956321018264893, acc:0.9393382352941176 (22484/23936)
batch:279/469, loss:-0.4955592581204006, acc:0.9392299107142857 (33662/35840)
batch:372/469, loss:-0.4950189735870259, acc:0.9388195375335121 (44823/47744)
batch:465/469, loss:-0.49565435664592383, acc:0.9390423819742489 (56012/59648)
Test ==> loss:0.20754675633167918, acc:0.9388 (9388/10000)
[Epoch:9]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5581856966018677, acc:0.984375 (126/128)
batch:93/469, loss:-0.5014005094132525, acc:0.9423204787234043 (11338/12032)
batch:186/469, loss:-0.49853202844048566, acc:0.9412182486631016 (22529/23936)
batch:279/469, loss:-0.5006332156913621, acc:0.9414341517857143 (33741/35840)
batch:372/469, loss:-0.5006915050441394, acc:0.9413329423592494 (44943/47744)
batch:465/469, loss:-0.5015091910126895, acc:0.9409703594420601 (56127/59648)
Test ==> loss:0.20188309327711032, acc:0.9411 (9411/10000)
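
The training loss in the log above turns negative from the second epoch on, which a true KL divergence cannot do. `nn.KLDivLoss` evaluates `target * (log(target) - input)` pointwise and expects `input` to hold log-probabilities. If plain probabilities are passed as `input`, the result is no longer a KL divergence and can fall below zero. The sketch below reproduces that rule in plain Python on hypothetical three-class distributions to show the difference.

```python
import math

def kl_div_loss(inputs, target):
    """Pointwise rule of nn.KLDivLoss, summed over classes: target * (log(target) - input)."""
    return sum(t * (math.log(t) - x) for x, t in zip(inputs, target) if t > 0)

# Hypothetical softened distributions over three classes
student_probs = [0.5, 0.3, 0.2]
teacher_probs = [0.6, 0.25, 0.15]

# Probabilities passed directly as the input: not a valid KL divergence
with_probs = kl_div_loss(student_probs, teacher_probs)

# Log-probabilities passed as the input: the actual KL(teacher || student)
with_log_probs = kl_div_loss([math.log(q) for q in student_probs], teacher_probs)

print(with_probs < 0)       # True
print(with_log_probs > 0)   # True
```

The test accuracy in the log is not affected by the sign of the training loss, since `validate()` uses only the cross-entropy criterion, but the gradient the student receives from the soft terms differs between the two variants.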

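In the end, joint distillation reaches a test accuracy of 0.9411, better than the 0.9375 of the student trained on its own (36 more correct predictions out of 10,000 test images).

So with suitable hyperparameters, joint distillation does beat training the student alone. The catch is that the hyperparameters have to be tuned well. Tuning is painful: if it is done badly, knowledge distillation does not deliver its benefit, and sometimes it has no effect or even makes the result worse.

I originally assigned two learnable weight parameters to the convolutional teacher and the fully connected teacher, but that did not train better than fixed weights. In my tests, a weight of about 0.2 for each teacher worked best.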
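
One way to see why learnable weights may not help: with the loss `alpha * L1 + gama * L2 + (1 - alpha - gama) * Lh`, the gradient with respect to `alpha` is simply `L1 - Lh`, and with respect to `gama` it is `L2 - Lh`. The weights therefore drift toward whichever term is currently smallest, regardless of whether that helps accuracy. The sketch below illustrates this with made-up constant loss values and plain gradient descent as a stand-in for the Adam optimizer used in the notebook.

```python
def weight_gradients(soft_1: float, soft_2: float, hard: float):
    """Gradients of alpha*soft_1 + gama*soft_2 + (1 - alpha - gama)*hard w.r.t. alpha and gama."""
    return soft_1 - hard, soft_2 - hard

alpha, gama = 0.2, 0.2   # initial weights, as in the notebook
lr = 3e-5                # learning rate of parms_optimizer in the notebook

# Hypothetical loss values, held constant for illustration
soft_1, soft_2, hard = 0.02, 0.05, 0.25

for _ in range(1000):
    grad_alpha, grad_gama = weight_gradients(soft_1, soft_2, hard)
    alpha -= lr * grad_alpha
    gama -= lr * grad_gama

hard_weight = 1 - alpha - gama
print(alpha, gama, hard_weight)
```

Because the soft losses are smaller than the hard loss in this example, both teacher weights grow and the hard-label weight shrinks below 0.6. Nothing in the formula keeps the weights inside [0, 1], so this kind of weighting is driven by loss magnitudes rather than by what improves the student.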

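The takeaway from this round: hyperparameter tuning is a craft, and deep learning really does feel like alchemy.


References:

1. "Knowledge Distillation | Principles of the knowledge distillation algorithm and other extensions"
2. "[29] Knowledge distillation test, and using learnable parameters to assist knowledge-distillation training of the Student model"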

Added: 2022-03-24 00:32:32  Updated: 2022-03-24 00:33:20