IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> RuntimeError: Error(s) in loading state_dict for BASE_Transformer -> 正文阅读

[人工智能]RuntimeError: Error(s) in loading state_dict for BASE_Transformer

最近跑一个深度学习变化检测的项目BIT_CD,严格按照作者的说明页进行训练和测试,但是跑出来的模型就是无法正常工作,而用作者的预训练模型就正常工作,百思不得其解,根据错误,逐步调试,输出,总算是找到了问题的所在!

其实这个问题如果对于老手,估计一下子就解决了,但是对于刚刚接触深度学习的新人,要发现并且解决这个问题确实需要费不少功夫。

那么,这里就记录一下这个问题的解决过程!

首先,错误的具体内容如下:

Traceback (most recent call last):
  File "D:/Tools/MyScript/BIT_CD/eval_cd.py", line 58, in <module>
    main()
  File "D:/Tools/MyScript/BIT_CD/eval_cd.py", line 54, in main
    model.eval_models(checkpoint_name=args.checkpoint_name)
  File "D:\Tools\MyScript\BIT_CD\models\evaluator.py", line 158, in eval_models
    self._load_checkpoint(checkpoint_name)
  File "D:\Tools\MyScript\BIT_CD\models\evaluator.py", line 70, in _load_checkpoint
    self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
  File "D:\MyItems\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BASE_Transformer:
	size mismatch for transformer_decoder.layers.0.0.fn.fn.to_q.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
	size mismatch for transformer_decoder.layers.0.0.fn.fn.to_k.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
	size mismatch for transformer_decoder.layers.0.0.fn.fn.to_v.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
	size mismatch for transformer_decoder.layers.0.0.fn.fn.to_out.0.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([32, 64]).

根据页面提示的错误,可以知道是加载模型中的state_dict维度不匹配。

于是,我尝试输出预训练模型,和我自己训练的模型的内容进行比较。

?结果,我一脸懵逼,看不出个所以然来,维度太高,括号太多,根本没法看。

但是可以注意到,两个模型中都存在model_G_state_dict关键字。

因为这样实在不好看维度,于是我换了一个输出方式,查看了mydict的类型,直接输出维度。

            mydict = checkpoint['model_G_state_dict']

            for k, v in mydict.items():
                print (k, '========' ,v.shape)

此时问题一目了然,中间层的维度不对,跟报错内容一致。

作者的预训练模型中间某个维度是64,32 而我训练出来的模型是512,32。

根据出错的提示

BASE_Transformer

以及transformer_decoder.layers,继续寻找出错位置。

经过整个工程搜索,来到代码位置:

class BASE_Transformer(ResNet):
    """
    Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN
    """
    def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5,
                 token_len=4, token_trans=True,
                 enc_depth=1, dec_depth=1,
                 dim_head=64, decoder_dim_head=64,
                 tokenizer=True, if_upsample_2x=True,
                 pool_mode='max', pool_size=2,
                 backbone='resnet18',
                 decoder_softmax=True, with_decoder_pos=None,
                 with_decoder=True):
        super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone,
                                             resnet_stages_num=resnet_stages_num,
                                               if_upsample_2x=if_upsample_2x,
                                               )

断点调试,跳出,得到位置:

self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)

输出这个self.net_G查看,可以发现此处的模型与预训练的模型已经不一致,那为何不一致呢?

关键跳转函数:

def define_G(args, init_type='normal', init_gain=0.02, gpu_ids=[]):
    if args.net_G == 'base_resnet18':
        net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False)

    elif args.net_G == 'base_transformer_pos_s4':
        net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
                             with_pos='learned')

    elif args.net_G == 'base_transformer_pos_s4_dd8':
        net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
                             with_pos='learned', enc_depth=1, dec_depth=8)

    elif args.net_G == 'base_transformer_pos_s4_dd8_dedim8':
        net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
                             with_pos='learned', enc_depth=1, dec_depth=8, decoder_dim_head=8)

    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % args.net_G)
    return init_net(net, init_type, init_gain, gpu_ids)

最终定位到原因,原来训练模型和测试模型使用的模型不同!

于是这就好办了,修改测试代码模型参数即可。

    parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str,
                        help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')

至此,成功训练并测试模型,问题解决!

虽然是个并不复杂的问题,但是我花了一天半的时间才完全解决,解决的过程中经过了无数次的尝试,测试和分析,最终成功解决了问题!

?

回过头看,这其实是一个很简单的问题,但是我仍然很高兴自己能够逐步分析,测试,最终找到正确的解决办法,而做到更深入地理解了代码。

本次遇到的问题比较特别,在网上几乎找不到答案,但是找到了跟我遇到相同问题的人,不管怎么说,这都是一件值得纪念的事情。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-09-04 01:12:01  更:2022-09-04 01:12:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 19:40:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计