之前想学习保存和加载模型的代码,在知乎上看到一个回答,发现两行代码就可以搞定,于是兴冲冲的加上了:
torch.save(model, "model.pth.tar")
model_dict=torch.load("model.pth.tar")
然后就大胆的去训练了,结果训练结束,准备load 时,发现load 得到的结果,就只有模型的结构,参数完全没保存下来… (哎,当时看到答主说这种方式是保存了整个网络,就以为整个网络必然包括参数啊,谁知道仅仅是结构)
于是换了一种方式:
checkpoint = {
"model_struct": model,
"model_param": model.state_dict(),
"model_cfg": config}
torch.save(checkpoint, “model.ckpt")
这种方式是自己建立一个字典checkpoint ,然后分别保存模型结构model_struct 、模型参数model_param 和相关配置model_cfg ,然后保存为.ckpt 文件(至于为何.pth.tar 和.ckpt 到底有什么不一样,暂时还不清楚)
这次的教训就是,单保存模型是不能保存网络参数的,需要调用模型的.state_dict() 属性将参数拿出来
参考文章: https://zhuanlan.zhihu.com/p/38056115
|