IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习网络模型的保存与加载 -> 正文阅读

[人工智能]深度学习网络模型的保存与加载

pytorch 保存与加载模型参数的最主要的三个函数

torch.save: 将序列化的对象存储到硬盘中.此函数使用Python的pickle实用程序进行序列化. 对于数据类型都可以进行序列化存储, 模型, 张量, 以及字典, 等各种数据对象都可以使用该函数存储.

torch.load: 该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也促进设备加载数据.

torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典

1. 权重参数保存的三种方式

第一种: 将网络模型和对应的参数保存在一起;

第二种: 模型和参数分离, 单独的保存模型的权重参数;

方式二推荐, 方便于网络模型修改后, 提取出对应层的参数权重

第三中: 除了权重参数, 用于模型训练的超参数也保存其中。

1.1 模型和参数保存在一起

  1. 保存时:
 net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;

torch.save( net123,"./weights/All_in.pth")

  1. 加载模型时:

# Model class must be defined somewhere
net123 = torch.load("./weights/All_in.pth")
model.eval()

使用该方法相当于跳过了对模型的 state_dict 描述的过程, 而是直接使用 python 的 pickle 包,

这种方法的缺点是, 模型的存储形式与加载形式十分固定, 这样做的原因是因为pickle不会保存模型类本身. 而是存出来包含该文件的路径,该路径在加载时使用.

因此,在其他项目中使用或重构后,代码可能会以各种方式中断. 但是这种方法存储的文件的类型与前面的方法一样. 同样, 以该方法加载模型运行之前需要调用model.eval() .

1.2 单独保存模型的参数-state_dict()

  1. 保存时:
net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;

torch.save(net123.state_dict(),'./weights/epoch_weight.pth')

  1. 加载模型时:

net123 = module.CustomModel(*args, **kwargs)
net123.load_state_dict(torch.load('epoch_weight.pth'))
model.eval()

从模型存储的角度, 存储模型的时候, 唯一需要存储的是该模型训练的参数, torch.save() 函数也可以存储模型的 state_dict. 使用该方法进行存储, 模型被看做字典形式, 所以对模型的操作更加灵活.

在这种形式下常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型.

注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先torch.load()该模型, 然后再使用 load_state_dict().

1.3 模型的参数+ 模型训练超参数

checkpoint 方式

  1. 保存时:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)


  1. 加载模型时:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

可以看出 checkpoints 是模型主要内容的一个字典, 基本包含了模型各种数据,

例如上面的例子模型的参数使用的是 optimizer.state_dict().

存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便.

为了存储一个训练过程的多种信息, 最好的方式是使用 dictionary 进行序列化, 这样存储一个训练模型的形式是 .tar,

要加载项目,首先初始化模型和优化器,然后使用torch.load() 在本地加载字典.从这里开始, 只需按期望查询字典即可轻松访问已保存的项目.

请记住,在运行推理之前,必须调用model.eval() 来将 Dropout 和 Batch 正则化设置为评估模式, 不这样做将产生不一致的推断结果. 如果恢复训练,那么调用model.train() 以确保这些层处于训练模式.

2. 多个模型权重信息存储在一个文件中

  1. 保存时:
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

  1. 加载模型时:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

保存包含多个 torch.nn.Modules 的模型(例如GAN,序列到序列模型或模型集合)时,将采用与保存常规检查点相同的方法。

换句话说,保存每个模型的state_dict和相应的优化器的字典. 如前所述,您可以保存任何其他可以帮助您恢复培训的项目,只需将它们添加到字典中即可. 使用该方法存储的文件也是 .tar 形式的, 要加载模型,请首先初始化模型和优化器,然后使用torch.load()在本地加载字典。 从这里,您只需按期望查询字典即可轻松访问已保存的项目.

3. 跨平台参数保存与加载

3.1 save on GPU , load on GPU

1.Save:

torch.save(model.state_dict(), PATH)
  1. Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

3.1 save on GPU , load on CPU

1.Save:

torch.save(model.state_dict(), PATH)
  1. Load:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

3.1 save on CPU , load on GPU

1.Save:

torch.save(model.state_dict(), PATH)
  1. Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model


4. 加载模型举例

# build
        encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)  # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
        decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
        
# reload pretrained word embeddings
        if params.reload_emb != '':
        	# 表示加载预训练模型
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(encoder, dico, word2id, embeddings)
            set_pretrain_emb(decoder, dico, word2id, embeddings)
            set_pretrain_emb(model2, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            enc_path, dec_path = params.reload_model.split(',')
            assert not (enc_path == '' and dec_path == '')

            # reload encoder
            if enc_path != '':
                enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                # 预训练模型是在 GPU 上训练的
                enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                # 导入存储的文件的模型
                if all([k.startswith('module.') for k in enc_reload.keys()]):
                    enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
                # 这个过程相当于将model 反序列化为 state_dict的形式
                encoder.load_state_dict(enc_reload, strict=False)
                # 这个后面的 strict=False 就是对 encoder 与 enc_reload.state_dict之间差异进行处理, 如果encoder 的模型结构与 enc_reload模型结构
                # 不一样的时候, 就会向 encoder 转化, 也就是 encoder 不包含的层就不会导入, 例如这里 enc_reload 就是一个完整的 Transformer 模型, 但是
                # encoder 是不包含输出部分的, 所以就不会加载这部分

对于该部分, 本文只是做了个简单的例子介绍, 更详细的内容参见传送门 . 对于这个传送门的例子, 如果我们先存储一个大模型, 将大模型加载到小模型的时候, 使用:

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path), strict=False)
for module in model.named_modules():
    print(module)
for name, param in model.named_parameters():
    print(name, param)

reference

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-12 16:27:40  更:2022-05-12 16:28:35 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 23:00:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码