参考资料: Pytorch官方文档链接 某博客
1. 3个函数
- torch.save() : 讲一个序列化对象保存到磁盘中。使用python的pickle工具。 模型 (model)、张量 (tensor) 和各种对象的字典 (dict) 都可以用这个函数保存。
- torch.load() :将pickle对象文件反序列化到内存中,也便于将数据加载到设备中
- torch.nn.Module.load_state_dict():加载模型的参数 state_dict介绍
Pytorch中,torch.nn.Module里面的可学习参数(weights和biases)都存在model.parameters()中。
2. 模型不同后缀名的区别
pytorch常见保存模型文件的后缀名有 .pt , .pth,.pkl。其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同。
3. 保存和重载模型
保存模型主要有两种方式: (1)只保存模型的参数,之后使用时再重新构建一个同样结构的新模型,然后把保存的参数导入新模型。(推荐)
torch.save(**model.state_dict()**,PATH)
model = TheModelClass(*args,**kwargs)
**model.load_state_dict(torch.load(PATH))**
model.eval()
(2)将整个模型保存下来,然后直接加载整个模型。(有点耗费内存…)
torch.save(**model**,PATH)
**model = torch.load(PATH)**
model.eval()
(3)如果没有训练完,仍然需要继续训练,除了model_state_dict需要保存,还需要保存optimizer_state_dict,epoch和loss。
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
model.train()
|