IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> # AssertionError: The `num_classes` (80) in Shared2FCBBoxHead of MMDataParallel does not matche -> 正文阅读

[Python知识库]# AssertionError: The `num_classes` (80) in Shared2FCBBoxHead of MMDataParallel does not matche

我看很多人都遇到了这个问题,有很多解决了的。我就把这篇博文再完善一下,让大家对mmdetection使用得心应手。


mmdetection训练自己的数据集时报错 ?? :

# AssertionError: The `num_classes` (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 80) in CocoDataset

你可能已经修改了以下两个文件,但是还是报错:

mmdetection-master\mmdet\core\evaluation\class_names.py

mmdetection-master\mmdet\datasets\coco.py

意思就是你指定的类别(3种)与CocoDataset的类别(80种)不匹配。

如果是报错翻过来的话,也就是你指定的类别(80种)与CocoDataset的类别(3种)不匹配。一定是配置文件里设置错了,去你的配置文件搜索num_classes,然后修改好。


废话不多说,直接上方法。有以下几种方法【经过我多次使用后,推荐第四种,方便的很】:

1?? 是修改最少的,假设你有2个类,你就把上边两处地方,前2个类替换成你的类别。方法比较简单,但是可能存在隐患。【不推荐】

2?? 第二种方法就是修改完 class_names.py 和 voc.py 之后一定要重新编译代码(运行python setup.py install),再进行训练。

我试了,有时候可以,有时候不行,可以尝试一下。

参考:

新版 MMDetection V2.3.0训练测试笔记 - it610.com

mmdetectionV2.x版本 训练自己的VOC数据集_桃子酱momo的博客-CSDN博客

3?? 第三种方法,我之前使用的方法,其实跟重新编译一样,重新编译的原因就是因为环境里的源文件没有修改,所以你才会报错。mmdetection-master目录下只是一些python文件,真正运行程序时,运行的还是环境里的源文件,因为我们直接去环境里修改源文件。

假设我的conda环境名为conda_env_name,因此去下面的目录下,分别修改两个文件:

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\core\evaluation\class_names.py

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\datasets\coco.py

在conda环境里把这两个文件里的类别修改了,就可以了,这一招一定可以。

4?? 第四种方法,更简单,更方便,我现在使用的方法。直接在mmdetection配置文件中指定好所有要指定的东西,因为在mmdetection中配置文件的参数值优先级是最高的,所以不用管环境里有没有修改,配置文件里修改了,就可以了。我写了个脚本,把脚本放到mmdetection根目录,根据自己要用的模型,把脚本中的变量都改成自己的。

我以cascade_mask_rcnn_r101为例:

# 在mmdetection的根目录下运行,如果报错:没有那个参数,就把create_mm_config中那个参数赋值给注释掉。生成配置文件后,直接修改配置文件就可以了。
import os
from mmcv import Config

#################################  下边是要修改的内容   ####################################

root_path = os.getcwd()
model_name = 'cascade_mask_rcnn_r101'  # 改成自己要使用的模型名字
work_dir = os.path.join(root_path, "work_dirs", model_name)  # 训练过程中,保存日志权重文件的路径,。
baseline_cfg_path = os.path.join('configs', 'cascade_rcnn', 'cascade_mask_rcnn_r101_fpn_mstrain_3x_coco.py')
# 改成自己要使用的模型的配置文件路径
save_cfg_path = os.path.join(work_dir, 'config.py')  # 生成的配置文件保存的路径

train_data_images = os.path.join(root_path, 'data', 'train', 'images')  # 改成自己训练集图片的目录。
val_data_images = os.path.join(root_path, 'data', 'train', 'images')  # 改成自己验证集图片的目录。
test_data_images = os.path.join(root_path, 'data', 'val', 'images')  # 改成自己测试集图片的目

train_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_train.json')  # 修改为自己的数据集的训练集json
val_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_val.json')  # 修改为自己的数据集的验证集json
test_ann_file = os.path.join(root_path, 'data', 'val', 'annotations', 'new_test.json')  # 修改为自己的数据集的验证集json录。

# 去找个网址里找你对应的模型的网址: https://github.com/open-mmlab/mmdetection/blob/master/README_zh-CN.md
load_from = os.path.join(work_dir, 'checkpoint.pth')  # 修改成自己的checkpoint.pth路径

# File config
num_classes = 50  # 改成自己的类别数。
classes = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10',
           '11', '12', '13', '14', '15', '16', '17', '18', '19',
           '20', '21', '22', '23', '24', '25', '26', '27', '28',
           '29', '30', '31', '32', '33', '34', '35', '36', '37',
           '38', '39', '40', '41', '42', '43', '44', '45', '46',
           '47', '48', '49', '50')  # 改成自己的类别,如果只有一个类别的话,要写成这样定义为元组: classes = ('1', )

###############  下边一些参数包含不全,可以在生成的配置文件中再对其他参数进行修改    #####################

# Train config              # 根据自己的需求对下面进行配置
gpu_ids = range(0, 1)  # 改成自己要用的gpu
gpu_num = 1
total_epochs = 20  # 改成自己想训练的总epoch数
batch_size = 2 ** 1  # 根据自己的显存,改成合适数值,建议是2的倍数。
num_worker = 1  # 比batch_size小,就行
log_interval = 300  # 日志打印的间隔
checkpoint_interval = 7  # 权重文件保存的间隔
lr = 0.02 * batch_size * gpu_num / 16  # 学习率
ratios = [0.5, 1.0, 2.0]
strides = [4, 8, 16, 32, 64]

cfg = Config.fromfile(baseline_cfg_path)

if not os.path.exists(work_dir):
    os.makedirs(work_dir)

cfg.work_dir = work_dir
print("Save config dir:", work_dir)

# swin和mmdetection的训练集配置不在一个地方,那个不报错用哪个:
cfg.classes = classes
# mmdetection用这个:
cfg.data.train.img_prefix = train_data_images
cfg.data.train.classes = classes
cfg.data.train.ann_file = train_ann_file
# swin用这个,注释上边那个
# cfg.data.train.dataset.img_prefix = train_data_images
# cfg.data.train.dataset.classes = classes
# cfg.data.train.dataset.ann_file = train_ann_file

cfg.data.val.img_prefix = val_data_images
cfg.data.val.classes = classes
cfg.data.val.ann_file = val_ann_file

cfg.data.test.img_prefix = test_data_images
cfg.data.test.classes = classes
cfg.data.test.ann_file = test_ann_file

cfg.data.samples_per_gpu = batch_size
cfg.data.workers_per_gpu = num_worker
cfg.log_config.interval = log_interval

# 有些配置文件num_classes可能不在这个地方,生成之后去配置文件里搜索一下,看看都修改了没
for head in cfg.model.roi_head.bbox_head:
    head.num_classes = num_classes
if "mask_head" in cfg.model.roi_head:
    cfg.model.roi_head.mask_head.num_classes = num_classes

cfg.load_from = load_from
cfg.runner.max_epochs = total_epochs
cfg.total_epochs = total_epochs
cfg.optimizer.lr = lr
cfg.checkpoint_config.interval = checkpoint_interval
cfg.model.rpn_head.anchor_generator.ratios = ratios
cfg.model.rpn_head.anchor_generator.strides = strides

cfg.dump(save_cfg_path)
print(save_cfg_path)
print("—" * 50)
print(f'CONFIG:\n{cfg.pretty_text}')
print("—" * 50)

生成配置文件后,路径在 ./work_dirs/cascade_mask_rcnn_r101/config.py,在mmdetection根目录下,又可以愉快的进行训练了。

训练命令(在mmdetection根目录):4GPU训练

./tools/dist_train.sh work_dirs/cascade_mask_rcnn_r101/config.py 4

?? 最终也功夫不负有心人,解决掉了这个bug,写此博客,以帮助大家少走弯路。

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-04-29 12:06:21  更:2022-04-29 12:07:49 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 15:45:18-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码