IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 知识蒸馏初学例子(只有hard_loss和soft_loss) -> 正文阅读

[人工智能]知识蒸馏初学例子(只有hard_loss和soft_loss)

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

# 设置随机数种子,便于复现,每次运行的输出结果都一样,因为每次运行rand随机的张量一样
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 使用cuDNN加速卷积运算
# 大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。
# 1.如果网络的输入数据维度或类型上变化不大,设置 torch.backends.cudnn.benchmark = true 可以增加运行效率;
# 2.如果网络的输入数据在每次 iteration 都变化的话,会导致 cnDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率。
torch.backends.cudnn.benchmark = True

# 载入MNIST数据集
# torchvision.datasets这个包中包含MNIST、FakeData、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,
# #root(string) - 数据集的根目录在哪里MNIST/processed/training.pt 和 MNIST/processed/test.pt存在。
# train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
# download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
# transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
# target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换。
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),# 将数据转换成tensor
    download=True
)

# 载入测试集

test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

# 生成dataloader
train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
print("train_loader:",train_loader)
test_loader = DataLoader(dataset=test_dataset,batch_size=32,shuffle=False)
print("test_loader:",test_loader)

# 教师模型
class TeacherModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self,x):
        x = x.view(-1,784) # view中一个参数定为-1,代表动态调整这个维度上的元素个数,以保证元素的总数不变
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

# 从头训练你教师模型
model = TeacherModel()
model = model.to(device)

# print(summary(model))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 6
for epoch in range(epochs):
    model.train()

    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)

        # 反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型性能
    # model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。
    # 如果使用model.eval()则BN层就不会再计算预测数据的均值和方差,即在预测过程中BN层的均值和方差就是训练过程得到的均值和方差mean_train,variance_train,此时预测结果就不会再发生变化。
    # 预测过程中如果不使用model.eval()的话,依然会使一部分的网络连接不进行计算,而使用model.eval()后就是所有的网络连接均进行计算。
    model.eval()
    num_correct = 0
    num_samples = 0

    # with torch.no_grad计算得到的新tensor的requires_grad为False,grad_fn也为None,即不会求导。
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            # print("x:",x.shape)# [32, 1, 28, 28]
            # print("y",y.shape)# [32]
            # print("preds:",preds.shape)# [32, 10]
            predictions = preds.max(1).indices
            # print("preds.max(1):",preds.max(1))# 有最大值的值,也有最大值的位置,.indices取位置张量
            # print("predictions:",predictions) # [1,32]
            num_correct += (predictions == y).sum()
            # print("num_correct:",num_correct)
            num_samples += predictions.size(0)
            # print("num_samples:",num_samples)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))

teacher_model = model

# 学生模型
class StudentModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

# 从头训练学生模型
model = StudentModel()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
    model.train()

    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)

        # 反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct +=(predictions == y).sum()
            num_samples +=predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))

student_model_scratch = model

# 知识蒸馏
# 准备好预训练的教师模型
teacher_model.eval()

# 准备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()

# 蒸馏温度
temp = 7
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss 权重
alpha =0.3
# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean") # kl散度,差不多是个交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
epochs = 3
for epoch in range(epochs):
    # 训练集上训练训练模型权重
    for data,targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        # 学生模型预测
        student_preds = model(data)
        # 计算呢hard_loss
        student_loss = hard_loss(student_preds,targets)

        # 计算蒸馏后的预测结果及soft_loss
        ditillation_loss = soft_loss(
            F.softmax(teacher_preds / temp, dim=1),
            F.softmax(student_preds / temp, dim=1)
        )

        # 将 hard_loss和soft_loss加权求和
        loss = alpha * student_loss + (1 - alpha) * ditillation_loss

        # 反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-14 09:57:29  更:2022-05-14 09:59:00 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 23:01:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码