IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 模型推理详细步骤以及如何排查模型和参数字典对不上的问题:Missing key(s) in state_dict: xxxx -> 正文阅读

[人工智能]模型推理详细步骤以及如何排查模型和参数字典对不上的问题:Missing key(s) in state_dict: xxxx

模型推理详细步骤

模型加载步骤

首先,模型加载总共分为三步,第一步加载网络结构,需要和你训时的network结构一样。

model = Model.FeedBack3(cfg, config_path=None, pretrained=True).to(device)

第二步,加载训练好的参数,实际上虽然我们一直说训练模型,实际上训练出来的就是一组参数,这个参数是一个字典类型,一般保存的名称为xxx.pt或者pth。里面存放的是模型每一层中的权重等数据。pytorch中对于加载参数使torch.load()

pretrained_dict = torch.load('outputmicrosoft-deberta-v3-base_fold3_best.pth')

第三步,将参数加载进模型里

model.load_state_dict(pretrained_dict['model_state_dict'], strict=True)

以上就是加载模型的所有步骤了

关于模型参数和字典对不上的问题

一般报错为:Missing key(s) in state_dict: xxxx
最近在做模型部署的时候发现了这个问题,并且之前也遇到过,由于急于求成就简单实在模型加载参数的时候用了strict=False这样的条件,这个条件会使模型直接忽略所有对不上的参数,本质上没有解决问题。今天在debug时对模型每一层的参数排查终于发现了问题所在。
首先开启debug模式,直接将断点打在模型加载的代码上:
首先查看model的结构有没有问题:
在这里插入图片描述
接下来进行下一步,执行到加载参数字典,同样查看你的参数字典(这里由于参数过多就不详细展示了):
在这里插入图片描述
那么要如何排查呢,具体步骤如下:
首先参数字典里都是以键值对和tensor型式存储的,那么我们只需要一一排查键值对和参数。比如首先是model建,那么只有你加载参数的时候只有加载里面的model建模型才能读到参数,实际上我就是错在这里了,因为我加载的是通常使用的‘model_state_dict’这个建,因为我训练部分是网上复制来的代码,没想到他把参数保存为model。
在这里插入图片描述

也就是我只需要把前面的

model.load_state_dict(pretrained_dict['model_state_dict'])

改成

model.load_state_dict(pretrained_dict['model'])

就行了。
那么如果你的问题不是这里,接下来改如何排查呢
接着看OrderedDict里,这里面是模型每一层的参数,对照方法如下:
在这里插入图片描述
相当于网络结构中的每一层都会变为一个对应的tensor
(model)(embeddings)(LayerNorm)在参数中就会存为:(‘model.embdeddings.LayerNorm’, tensor([xxxxx])

这样就看懂了吧,如此对照每一层网络结构,只要你有耐心,就能找出来具体是那一层不对,不过大多情况下这种在网络中间层出现参数不对的情况很少,出现的原因也肯定是你推理部分加载的网络结构和训练时的网络结构不一致导致的。
顺便推荐一个能帮你排查模型参数的代码,他会输出具体有多少参数使用了和没使用:

def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    # filter 'num_batches_tracked'
    missing_keys = [x for x in missing_keys
                    if not x.endswith('num_batches_tracked')]
    if len(missing_keys) > 0:
        print('[Warning] missing keys: {}'.format(missing_keys))
        print('missing keys:{}'.format(len(missing_keys)))
    if len(unused_pretrained_keys) > 0:
        print('[Warning] unused_pretrained_keys: {}'.format(
            unused_pretrained_keys))
        print('unused checkpoint keys:{}'.format(
            len(unused_pretrained_keys)))
    print('used keys:{}'.format(len(used_pretrained_keys)))

    assert len(used_pretrained_keys) > 0, \
        'check_key load NONE from pretrained checkpoint'
    return True

模型推理中的数据处理

首先模型推理中数据最终的处理格式要和训练时输入进网络中的格式一致,不过我们通常不再构造新的dataset和使用dataloader,而是直接针对input处理成我们需要的格式。
主要步骤为,读取数据,embedding,增加维度
读取的数据可以是本地存的,如果你是要将模型部署在web上那么数据就是从客户端传来的json格式的数据,因此通常需要先将真正的input取出来。
接下来是向量化,这里步骤和训练中的一致,比如训练中使用了resize([800,800])和toTensor,那么推理中也要这样设置。
由于我是NLP任务,那么处理的步骤为

inputs = cfg.tokenizer.encode_plus(
            input,
            return_tensors=None,
            add_special_tokens=True,
            max_length=cfg.max_lenth,
            pad_to_max_length=True,
            truncation=True
        )
 for k, v in inputs.items():
    inputs[k] = torch.tensor(v, dtype=torch.long)

至此,再次输出此时的tensor和训练时输入进模型的tensor相比,只是少了一个维度,这个维度通常可以理解我们在训练的时候是有batch_size的,而推理时没有,因此要手动升维,升维度的函数有很多,通常使用unsequeeze(1)或者expand:

for k, v in inputs.items():
    s = v.shape
    inputs[k] = v.expand(1,-1).to(device) #-1自动计算

这样处理完数据格式就和训练时完全一致了,说白了还是要先debug一下训练时的数据,看看到底输进去的是什么格式,然后在推理部分照着一点一点改。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-11-05 00:28:48  更:2022-11-05 00:32:42 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年4日历 -2024/4/19 23:59:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码