pytorch 保存与加载模型参数的最主要的三个函数
torch.save: 将序列化的对象存储到硬盘中.此函数使用Python的pickle实用程序进行序列化. 对于数据类型都可以进行序列化存储, 模型, 张量, 以及字典, 等各种数据对象都可以使用该函数存储.
torch.load: 该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也促进设备加载数据.
torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典
1. 权重参数保存的三种方式
第一种: 将网络模型和对应的参数保存在一起;
第二种: 模型和参数分离, 单独的保存模型的权重参数;
方式二推荐, 方便于网络模型修改后, 提取出对应层的参数权重;
第三中: 除了权重参数, 用于模型训练的超参数也保存其中。
1.1 模型和参数保存在一起
- 保存时:
net123 = module.CustomModel()
torch.save( net123,"./weights/All_in.pth")
- 加载模型时:
net123 = torch.load("./weights/All_in.pth")
model.eval()
使用该方法相当于跳过了对模型的 state_dict 描述的过程, 而是直接使用 python 的 pickle 包,
这种方法的缺点是, 模型的存储形式与加载形式十分固定, 这样做的原因是因为pickle不会保存模型类本身. 而是存出来包含该文件的路径,该路径在加载时使用.
因此,在其他项目中使用或重构后,代码可能会以各种方式中断. 但是这种方法存储的文件的类型与前面的方法一样. 同样, 以该方法加载模型运行之前需要调用model.eval() .
1.2 单独保存模型的参数-state_dict()
- 保存时:
net123 = module.CustomModel()
torch.save(net123.state_dict(),'./weights/epoch_weight.pth')
- 加载模型时:
net123 = module.CustomModel(*args, **kwargs)
net123.load_state_dict(torch.load('epoch_weight.pth'))
model.eval()
从模型存储的角度, 存储模型的时候, 唯一需要存储的是该模型训练的参数, torch.save() 函数也可以存储模型的 state_dict . 使用该方法进行存储, 模型被看做字典形式, 所以对模型的操作更加灵活.
在这种形式下常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型.
注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先torch.load() 该模型, 然后再使用 load_state_dict() .
1.3 模型的参数+ 模型训练超参数
checkpoint 方式
- 保存时:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
- 加载模型时:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
model.train()
可以看出 checkpoints 是模型主要内容的一个字典, 基本包含了模型各种数据,
例如上面的例子模型的参数使用的是 optimizer.state_dict().
存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便.
为了存储一个训练过程的多种信息, 最好的方式是使用 dictionary 进行序列化, 这样存储一个训练模型的形式是 .tar ,
要加载项目,首先初始化模型和优化器,然后使用torch.load() 在本地加载字典.从这里开始, 只需按期望查询字典即可轻松访问已保存的项目.
请记住,在运行推理之前,必须调用model.eval() 来将 Dropout 和 Batch 正则化设置为评估模式, 不这样做将产生不一致的推断结果. 如果恢复训练,那么调用model.train() 以确保这些层处于训练模式.
2. 多个模型权重信息存储在一个文件中
- 保存时:
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
- 加载模型时:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
modelA.train()
modelB.train()
保存包含多个 torch.nn.Modules 的模型(例如GAN,序列到序列模型或模型集合)时,将采用与保存常规检查点相同的方法。
换句话说,保存每个模型的state_dict和相应的优化器的字典. 如前所述,您可以保存任何其他可以帮助您恢复培训的项目,只需将它们添加到字典中即可. 使用该方法存储的文件也是 .tar 形式的, 要加载模型,请首先初始化模型和优化器,然后使用torch.load()在本地加载字典。 从这里,您只需按期望查询字典即可轻松访问已保存的项目.
3. 跨平台参数保存与加载
3.1 save on GPU , load on GPU
1.Save:
torch.save(model.state_dict(), PATH)
- Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
3.1 save on GPU , load on CPU
1.Save:
torch.save(model.state_dict(), PATH)
- Load:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
3.1 save on CPU , load on GPU
1.Save:
torch.save(model.state_dict(), PATH)
- Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
4. 加载模型举例
encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)
decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
if params.reload_emb != '':
word2id, embeddings = load_embeddings(params.reload_emb, params)
set_pretrain_emb(encoder, dico, word2id, embeddings)
set_pretrain_emb(decoder, dico, word2id, embeddings)
set_pretrain_emb(model2, dico, word2id, embeddings)
if params.reload_model != '':
enc_path, dec_path = params.reload_model.split(',')
assert not (enc_path == '' and dec_path == '')
if enc_path != '':
enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
if all([k.startswith('module.') for k in enc_reload.keys()]):
enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
encoder.load_state_dict(enc_reload, strict=False)
对于该部分, 本文只是做了个简单的例子介绍, 更详细的内容参见传送门 . 对于这个传送门的例子, 如果我们先存储一个大模型, 将大模型加载到小模型的时候, 使用:
path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path), strict=False)
for module in model.named_modules():
print(module)
for name, param in model.named_parameters():
print(name, param)
reference
|