import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
# Fix the random seed for reproducibility: every run draws the same random tensors, so the outputs are repeatable
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN accelerate convolution operations (this script's models are fully connected, so the effect here is minimal)
# In most cases, enabling this flag lets cuDNN's auto-tuner search for the most efficient algorithms for the current configuration, improving runtime.
# 1. If the network's input sizes/types vary little, torch.backends.cudnn.benchmark = True can speed things up;
# 2. If the inputs change on every iteration, cuDNN re-runs its search each time, which can actually hurt performance.
torch.backends.cudnn.benchmark = True
# Load the MNIST dataset
# torchvision.datasets provides common datasets such as MNIST, FakeData, COCO, LSUN, ImageFolder, DatasetFolder, ImageNet, CIFAR, etc.
# root (string)                         - root directory where MNIST/processed/training.pt and MNIST/processed/test.pt live.
# train (bool, optional)                - if True, build the dataset from training.pt, otherwise from test.pt.
# download (bool, optional)             - if True, download the dataset from the Internet into root; skipped if it is already downloaded.
# transform (callable, optional)        - a function/transform that takes a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
# target_transform (callable, optional) - a function/transform applied to the target.
train_dataset = torchvision.datasets.MNIST(
root="dataset/",
train=True,
    transform=transforms.ToTensor(),  # convert PIL images to tensors
download=True
)
# Load the test set
test_dataset = torchvision.datasets.MNIST(
root="dataset/",
train=False,
transform=transforms.ToTensor(),
download=True
)
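# Quick illustrative check (sample_img/sample_label are names introduced just for this check):
# each MNIST sample is an (image, label) pair, and ToTensor yields a float tensor of shape
# [1, 28, 28] with pixel values scaled into [0, 1].
sample_img, sample_label = train_dataset[0]
print("sample image shape:", sample_img.shape, "label:", sample_label)
print("pixel range:", sample_img.min().item(), "-", sample_img.max().item())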
# Build the dataloaders
train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
print("train_loader:",train_loader)
test_loader = DataLoader(dataset=test_dataset,batch_size=32,shuffle=False)
print("test_loader:",test_loader)
# Teacher model
class TeacherModel(nn.Module):
def __init__(self,in_channels=1,num_classes=10):
super(TeacherModel,self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784,1200)
self.fc2 = nn.Linear(1200, 1200)
self.fc3 = nn.Linear(1200, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self,x):
        x = x.view(-1, 784)  # a -1 dimension lets view infer its size automatically so the total number of elements is preserved
x = self.fc1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc2(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc3(x)
return x
# Train the teacher model from scratch
model = TeacherModel()
model = model.to(device)
# print(summary(model, input_size=(32, 1, 28, 28)))  # optional: torchinfo summary with per-layer output shapes
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 6
for epoch in range(epochs):
model.train()
    # Train the model weights on the training set
for data, targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
        # Forward pass
preds = model(data)
loss = criterion(preds, targets)
        # Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
    # Evaluate the model on the test set
    # model.train() and model.eval() differ mainly in how BatchNorm and Dropout layers behave.
    # With model.eval(), BN layers stop computing statistics on the incoming data and instead use the
    # mean/variance accumulated during training, so predictions become deterministic.
    # Without model.eval(), Dropout still disables part of the connections at inference time; with it, all connections are used.
model.eval()
num_correct = 0
num_samples = 0
    # Tensors created inside torch.no_grad() have requires_grad=False and grad_fn=None, i.e. no gradients are tracked.
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
# print("x:",x.shape)# [32, 1, 28, 28]
# print("y",y.shape)# [32]
# print("preds:",preds.shape)# [32, 10]
predictions = preds.max(1).indices
# print("preds.max(1):",preds.max(1))# 有最大值的值,也有最大值的位置,.indices取位置张量
# print("predictions:",predictions) # [1,32]
num_correct += (predictions == y).sum()
# print("num_correct:",num_correct)
num_samples += predictions.size(0)
# print("num_samples:",num_samples)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))
teacher_model = model
# Student model
class StudentModel(nn.Module):
def __init__(self,in_channels=1,num_classes=10):
super(StudentModel,self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784, 20)
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self,x):
x = x.view(-1,784)
x = self.fc1(x)
# x = self.dropout(x)
x = self.relu(x)
x = self.fc2(x)
# x = self.dropout(x)
x = self.relu(x)
x = self.fc3(x)
return x
# Train the student model from scratch
model = StudentModel()
model = model.to(device)
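# Illustrative size comparison (count_parameters is a small helper added just for this check):
# with the layer sizes defined above, the student has roughly 16K parameters versus about 2.4M
# for the teacher, i.e. ~150x smaller, which is why we want to transfer the teacher's knowledge into it.
def count_parameters(m):
    return sum(p.numel() for p in m.parameters())
print("teacher parameters:", count_parameters(teacher_model))  # ~2.4M
print("student parameters:", count_parameters(model))          # ~16K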
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
model.train()
    # Train the model weights on the training set
for data, targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
        # Forward pass
preds = model(data)
loss = criterion(preds, targets)
        # Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
    # Evaluate the model on the test set
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct +=(predictions == y).sum()
num_samples +=predictions.size(0)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))
student_model_scratch = model
# Knowledge distillation
# Prepare the pretrained teacher model
teacher_model.eval()
# Prepare a fresh student model
model = StudentModel()
model = model.to(device)
model.train()
# Distillation temperature
temp = 7
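# Small illustration of what the temperature does (standalone, does not affect training;
# _demo_logits is a made-up example): dividing logits by temp > 1 flattens the softmax,
# so the relative probabilities of the non-target classes ("dark knowledge") become more
# visible in the soft targets the student learns from.
_demo_logits = torch.tensor([[8.0, 2.0, 0.5]])
print("softmax at T=1:", F.softmax(_demo_logits, dim=1))        # sharply peaked
print("softmax at T=7:", F.softmax(_demo_logits / temp, dim=1)) # much softer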
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# weight of hard_loss
alpha = 0.3
# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")  # KL divergence; expects log-probabilities as input and probabilities as target
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
epochs = 3
for epoch in range(epochs):
    # Train the model weights on the training set
for data,targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
        # Teacher predictions (no gradients needed)
with torch.no_grad():
teacher_preds = teacher_model(data)
        # Student predictions
student_preds = model(data)
        # Compute the hard loss against the ground-truth labels
student_loss = hard_loss(student_preds,targets)
        # Compute the soft loss between the softened student and teacher distributions.
        # nn.KLDivLoss expects log-probabilities as the input and probabilities as the target,
        # so the student logits go through log_softmax and the teacher logits through softmax.
        # (Hinton et al. additionally scale the soft loss by temp**2 to balance its gradients against the hard loss; that is omitted here.)
        distillation_loss = soft_loss(
            F.log_softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        # Weighted sum of hard_loss and soft_loss
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        # Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
    # Evaluate the model on the test set
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))
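# Keep a handle on the distilled student, mirroring teacher_model and student_model_scratch above,
# so the three models can be compared afterwards (the variable name is introduced here for symmetry).
student_model_distill = model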