Advantages:
1. For long training runs, if training is interrupted you can resume by loading the last saved state directly.
2. Through transfer learning, you can start from weights that others have already trained, improving training results.
This is explained from three aspects:
1. Saving and loading a model
2. Freezing some parameters and training the rest
3. Training different parameter groups with different learning rates
(Minimal sketches for aspects 2 and 3 are given at the end of this section.)
Saving and loading a model
Three ways to save and load a model:
Method 1 (recommended): save and load only the model's state_dict

torch.save(model.state_dict(), PATH)

model.load_state_dict(torch.load(PATH))
model.eval()  # switch dropout/batch-norm layers to evaluation mode before inference
Method 2: save and load the entire model (this pickles the whole module, so the class definition must be importable when loading)

torch.save(model, PATH)

model = torch.load(PATH)
model.eval()
Method 3: save a checkpoint dictionary, which bundles everything needed to resume training

torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...}, PATH)

checkpoint = torch.load(PATH)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])

model.train()  # call this to resume training
model.eval()   # or this, to run inference
The state dict (state_dict) is a Python dictionary object that makes it easy to save, update, modify, and restore parameters. For a torch.nn.Module, the model's learnable parameters (accessible via model.parameters()) are what its state_dict contains; for a torch.optim optimizer, the state_dict contains the optimizer's state information and the hyperparameters it uses.
Note: only layers with learnable parameters have entries in a model's state_dict.
Test code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

class ModelClass(nn.Module):
    def __init__(self):
        super(ModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = ModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Each entry in the model's state_dict maps a parameter name to its tensor
print("------------model's state_dict---------------------")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# The optimizer's state_dict holds its state and the hyperparameters in use
print("-------------optimizer's state_dict--------------------")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
# Try to resume from a previously saved checkpoint
print('===> try resume from checkpoint')
if os.path.isdir('checkpoint'):
    try:
        checkpoint = torch.load('./XXX/xxx.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        print('===> load last checkpoint data')
    except FileNotFoundError:
        print("can't find checkpoint file")
else:
    start_epoch = 0
    print('===> start from scratch')
A helper for partially loading a checkpoint, which is useful for transfer learning: only the entries whose keys match the current model are loaded.

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print('loading checkpoint...')
        model_dict = model.state_dict()
        model_checkpoint = torch.load(checkpoint)
        pretrained_dict = model_checkpoint['model_state_dict']
        # keep only the pretrained entries whose keys also exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print('--------------loading finished!--------------')
        if loadOptimizer:
            optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
            print('---------------loaded optimizer-------------')
        else:
            print('optimizer not loaded')
    else:
        print('no checkpoint is included')
    return model, optimizer
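A hypothetical call (the checkpoint path is a placeholder), loading both the model and the optimizer state:

model, optimizer = load_checkpoint(model, './checkpoint/ckpt.pth', optimizer, loadOptimizer=True)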
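Aspect 2, freezing some parameters and training the rest: a minimal sketch, assuming the ModelClass above and that only the fully connected layers should be fine-tuned (which layers to unfreeze is an assumption for illustration):

# Freeze every parameter, then unfreeze only the fully connected layers
for param in model.parameters():
    param.requires_grad = False
for layer in (model.fc1, model.fc2, model.fc3):
    for param in layer.parameters():
        param.requires_grad = True

# Give the optimizer only the parameters that still require gradients
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001, momentum=0.9)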
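Aspect 3, training with different learning rates: torch.optim optimizers accept per-parameter-group options, so pretrained layers can be trained with a small learning rate while newly added layers use a larger one. A sketch, again assuming ModelClass (the learning-rate values are illustrative assumptions):

optimizer = optim.SGD([
    # pretrained convolutional layers: small learning rate
    {'params': model.conv1.parameters(), 'lr': 1e-4},
    {'params': model.conv2.parameters(), 'lr': 1e-4},
    # fully connected layers fall back to the default lr below
    {'params': model.fc1.parameters()},
    {'params': model.fc2.parameters()},
    {'params': model.fc3.parameters()},
], lr=1e-2, momentum=0.9)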