IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习中模型权重数据读取、保存 -> 正文阅读

[人工智能]深度学习中模型权重数据读取、保存

作者:recommend-item-box type_blog clearfix

优点:

1.长时间的训练,如果发生中断,继续训练时直接读取
2.通过迁移学习,利用别人训练好的数据进行训练,提高训练效果

三个方面说明

1.模型保存与加载
2.冻结一部分参数,训练另一部分参数
3.采用不同的学习率进行训练

模型保存与加载

模型保存与加载的三种方式

# 方式一:保存与加载整个state_dict(推荐)
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))  # 继承自torch.nn.Module.load_state_dict
# 测试时不启用BatchNormalization 和DropOut
model.eval()
# 方式二:保存加载整个模型
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
			'optimizer_state_dict': optimizer.state_dict(), 'loss': loss,
			 ...}, PATH)
checkpoint = torch.load(PATH)
start_epoch = checkpoint['epoch']
model.load_state_dict(chechpoint['model_state_dict'])
# 训练时
model.train()
# 测试时
model.eval()

状态字典:state_dict,python的字典对象,保存、更新、修改和恢复,
torch.nn.Module模型的可学习参数包含在模型中(使用model.parameters()进行访问);
torch.optim中,包含优化器的状态信息、使用的超参数等

注意:只有可学习参数的层的模型才具有state_dict。

测试代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

class ModelClass(nn.Module):
    def __init__(self):
        super(ModelClass, self).__init__()
        # self.pointnet = model(pretrained) # 冻结部分参数,
        # for p in self.parameters():
        #     p.requires_grad = False
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = ModelClass()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("------------model's state_dict---------------------")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("-------------optimizer's state_dict--------------------")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

print('===> try resume from checkpoint')
if os.path.isdir('checkpoint'):
    try:
        checkpoint = torch.load('./XXX/xxx.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        print('===> load last checkpoint data')
    except FileNotFoundError:
        print(" can't found file")
    else:
        start_epoch = 0
        print('===> start from scratch')

# 冻结部分参数, 训练另一部分参数
    # 1.在模型中添加参数
        # self.pointnet = model(pretrained) # 冻结部分参数,
        # for p in self.parameters():
        #     p.requires_grad = False

        # 1.1 for i, p in enumerate(model.parameters()):
                # if i < xxx:
                # p.requires_grad = False
    # 2.在优化器中添加
        # filter(lambda p: p.requires_grad, model.parameters())

# 在模型中修改部分网络参数,如增减等,则需要过滤这些参数、加载方式

def load_ckeckpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print('loading checkpoint')
        model_dict = model.state_dict()
        model_checkpoint = torch.load(checkpoint)
        pretrained_dict = model_checkpoint['model_state_dict']
        # 过滤操作
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # 打印更新的参数
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print('--------------loaded finished!--------------')

        if loadOptimizer == True:
            optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
            print('---------------loaded optimizer-------------')
        else:
            print('not load optimizer')

    else:
        print('no checkpoint is included')

    return model, optimizer

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-06 11:07:54  更:2021-09-06 11:08:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 20:00:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码