IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法 -> 正文阅读

[人工智能]Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

需求

Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层。(权重文件存储为dict形式)

方法一

常见方法:加载权重时用if对网络层进行筛选

'''
# model为定义的网络结构:
class model(nn.Module):
    def __init__(self):
        super(model,self).__init__()
        ……

    def forward(self,x):
        ……
        return x
'''

model = model()  
# load存在的模型参数(权重文件),后缀名可能不同? ? 
pretrained_dict = torch.load('model.pkl')
model_dict = model.state_dict()
# 关键在于下面这句,从model_dict中读取key、value时,用if筛选掉不需要的网络层 
pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and 'Prediction' not in key)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

方法二

不完全匹配,只加载权重中存在的参数,不匹配就跳过

# load_state_dict() 默认strict=True,需要完全匹配,否则报错
# 修改为strict=False后,只匹配存在的参数
pretrained_dict = torch.load(weight_path)
model.load_state_dict(pretrained_dict, strict=False)

方法三

?不使用原有权重文件训练,对原有权重文件进行拷贝,拷贝文件中只包含需要的网络层,后续直接利用拷贝权重文件进行训练。

    # 对原有权重文件进行拷贝,拷贝文件中只包含需要的网络层,
    # 后续直接利用拷贝文件进行训练。
    import pickle

    model = model()
    net = model
    path_weight = 'R-50.pkl'
    path_weight2 = 'R2-50.pkl'

    with open(path_weight,'rb') as f:
        obj=f.read()
    # 用pickle.loads()加载权重信息
    la_obj=pickle.loads(obj,encoding='latin1')
    # 用if进行筛选
    weights= {key: value for key, value in la_obj.items()}
              #if key in la_obj and 'backbone.bottom_up.stem.conv1.weight' not in key}
    # 使用print查看权重文件信息 
    print(weights)
    
    # 再深拷贝一份文件保存
    state_dict = copy.deepcopy(weights)
    with open(path_weight2,'wb') as f2:
        pickle.dump(state_dict, f2)

    # 可以写入txt,便于查看信息
    path_weight2 = 'R2-101.txt'
    inf = str(state_dict)
    ff = open(path_weight2,'w')
    ff.write(inf)

下面是对载入参数的优化有特殊要求:参数固定、或者参数更新速度不同。

方法四

如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
    if name 满足某些条件:
        value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

方法五

如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
? ? print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? {'params':model.decoder.parameters()}
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的learning_rate=1e-6。
在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。
?

遇见的问题

torch.load 加载权重文件时报错 Magic Number Error?

有时候使用 torch.load 加载比较古老的权重文件时可能报错 Magic Number Error,这有可能是因为该文件使用 pickle 存储并且编码使用了 latin1,此时可以这样加载:

若要进行筛选,同理可以在后面加上if进行判断。

import pickle
with open(weights_path, 'rb') as f:
    obj = f.read()
# 用pickle进行load,编码方式为latin1
weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding='latin1').items()}
# 同理,可以用if判断进行筛选
# weights = {key: value for key, value in pickle.loads(obj,encoding='latin1').items() if (key in model_dict and 'Prediction' not in key)}
model.load_state_dict(weights) 

TypeError: a bytes-like object is required, not 'str'

python3和python2在套接字返回值解码上有区别。

套接字就是?socket,用于描述 IP 地址和端口,应用程序通过套接字向网络发出请求或者应答网络请求,可以认为是计算机网络的数据接口。目前套接字分为两种:基于文件型和基于网络型。

解决方法

使用函数 encode() 和decode():

  1. str 通过 encode() 函数编码变为 bytes
  2. bytes 通过 decode() 函数编码变为 str。(当我们从网络或磁盘上读取了字节流,则读到的数据就是 bytes)

补充:

str --> bytes

# 声明一个字符串s:
>>> s = 'abc'
>>> type(s)
<class 'str'>

# 四种转换方式:
>>> b1 = s.encode()
>>> type(b1)
<class 'bytes'>
>>> b2 = str.encode(s)
>>> type(b2)
<class 'bytes'>
>>> b3 = s.encode(encoding='utf-8')
>>> type(b3)
<class 'bytes'>
>>> b4 = bytes(s,encoding='utf-8')
>>> type(b4)
<class 'bytes'>

bytes --> str

# 声明一个bytes:
>>> b = b'abc'
>>> type(b)
<class 'bytes'>

# 三种转换方式:
>>> s1 = bytes.decode(b)
>>> type(s1)
<class 'str'>
>>> s2 = b.decode()
>>> type(s2)
<class 'str'>
>>> s3 = str(b,encoding='utf-8')
>>> type(s3)
<class 'str'>

参考博客

Pytorch中只导入部分层权重的方法_汐梦聆海的博客-CSDN博客_pytorch加载部分权重

pytorch微调模型—只加载预训练模型的某些层_农夫山泉2号的博客-CSDN博客

Pytorch加载模型不完全匹配 & 只加载部分参数权重 load_hxxjxw的博客-CSDN博客_pytorch加载模型不匹配跳过

pytorch载入预训练模型后,只想训练个别层怎么办?_慕白-的博客-CSDN博客_pytorch只训练最后一层

PyTorch | 保存和加载模型 - 知乎 (zhihu.com)

torch.load加载权重时报错 Magic Number Error - 仰望高端玩家的小清新 - 博客园 (cnblogs.com)

Python报错:TypeError: a bytes-like object is required, not ‘str‘_程序媛三妹的博客-CSDN博客?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-30 08:42:53  更:2022-04-30 08:45:59 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:38:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码