??网上MMDetection的教程看似有很多,但感觉都不成系统,看完一圈下来还是不知道MMDetection要怎么用。这里还是推荐直接跟着官方教程,结合源码学习MMDetection,相关链接汇总如下:
- 官方教程 - MMCV
- 官方教程 - MMDetection
- 官方教程 - 不得不知的 MMDetection 学习路线(个人经验版)
- 西安交大课件 - mmdetection教程(使用篇)
??本文会介绍如何在MMDetection中从头开始搭建一套属于自己的算法。前几篇博客算是本人学习过程中的笔记,从源码本身分析了MMDetection的原理,比较细碎,本篇博客会从宏观的角度重新梳理一下MMDetection的使用方法以及流程原理,算是对之前一个月学习过程的总结。
- MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)
- MMDetection框架入门教程(二):快速上手教程
- MMDetection框架入门教程(三):配置文件详细解析
- MMDetection框架入门教程(四):注册机制详解
- MMDetection框架入门教程(五):Runner和Hook详细解析
1. 框架概述
??MMDetection是商汤和港中文大学针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块,通过模块调用的方式,我们能够以很少的代码量实现一个新算法,大大提高了代码复用率。
??整个MMLab家族除了MMDetection,还包含针对目标跟踪任务的MMTracking,针对3D目标检测任务的MMDetection3D等开源项目,他们都是以Pytorch和MMCV以基础。Pytorch不需要过多介绍,MMCV是一个面向计算机视觉的基础库,最主要作用是提供了基于Pytorch的通用训练框架,比如我们常提到的Registry、Runner、Hook等功能都是在MMCV中支持的。另外,MMCV还提供了通用IO接口、多种CNN网络结构、高质量实现的常见CUDA算子,这里就不进一步展开了。
2. 框架整体流程
2.1 Pytorch
??我们使用Pytorch构建一个新算法时,通常包含如下几步:
- 构建数据集:新建一个类,并继承
Dataset 类,重写__getitem__() 方法实现数据和标签的加载和遍历功能,并以pipeline的方式定义数据预处理流程 - 构建数据加载器:传入相应的参数实例化DataLoader
- 构建模型:新建一个类,并继承
Module 类,重写forward() 函数定义模型的前向过程 - 定义损失函数和优化器:根据算法选择合适和损失函数和优化器
- 训练和验证:循环从DataLoader中获取数据和标签,送入网络模型,计算loss,根据反传的梯度使用优化器进行迭代优化
- 其他操作:在主调函数里可以任意穿插训练Tricks、日志打印、检查点保存等操作
2.2 MMDetection
??使用Pytorch构建一个新算法时,通常包含如下几步:
- 注册数据集:
CustomDataset 是MMDetection在原始的Dataset 基础上的再次封装,其__getitem__() 方法会根据训练和测试模式分别重定向到prepare_train_img() 和prepare_test_img() 函数。用户以继承CustomDataset 类的方式构建自己的数据集时,需要重写load_annotations() 和get_ann_info() 函数,定义数据和标签的加载及遍历方式。完成数据集类的定义后,还需要使用DATASETS.register_module() 进行模块注册。 - 注册模型:模型构建的方式和Pytorch类似,都是新建一个
Module 的子类然后重写forward() 函数。唯一的区别在于MMDetection中需要继承BaseModule 而不是Module ,BaseModule 是Module 的子类,MMLab中的任何模型都必须继承此类。另外,MMDetection将一个完整的模型拆分为backbone、neck和head三部分进行管理,所以用户需要按照这种方式,将算法模型拆解成3个类,分别使用BACKBONES.register_module() 、NECKS.register_module() 和HEADS.register_module() 完成模块注册。 - 构建配置文件:配置文件用于配置算法各个组件的运行参数,大体上可以包含四个部分:datasets、models、schedules和runtime。完成相应模块的定义和注册后,在配置文件中配置好相应的运行参数,然后MMDetection就会通过
Registry 类读取并解析配置文件,完成模块的实例化。另外,配置文件可以通过_base_ 字段实现继承功能,以提高代码复用率。 - 训练和验证:在完成各模块的代码实现、模块的注册、配置文件的编写后,就可以使用
./tools/train.py 和./tools/test.py 对模型进行训练和验证,不需要用户编写额外的代码。
2.3 流程对比
??虽然从步骤上看MMDetection相比Pytorch的算法实现步骤存在挺大差异,但底层的逻辑实现和Pytorch本质上还是一样的,可以参考下图对照着进行理解,其中蓝色部分表示Pytorch流程,橙色部分表示MMDetection流程,绿色部分表示和算法框架无关的通用流程。
3. MMCV的注册机制
??在开始讲算法实现流程之前,必须要先了解MMDetection中的注册机制。
??MMDetection作为MMCV的下游项目,继承了MMCV的模块管理方式——注册机制。简单来说,注册机制就是维护几张查询表,key是模块的名称,value是模块的句柄,每张查询表都管理一批功能相似的不同模块。我们每新建一个模块,都要根据模块实现的功能将对应的key-value 查询对保存到对应的查询表中,这个保存的过程就称为“注册”。当我们想要调用某个模块时,只需要根据模块名称从查询表中找到对应的模块句柄,然后就能完成模块初始化或方法调用等操作。MMCV通过Registry 类来实现字符串(key)到类(value)的映射。
??Registry的构造函数如下所示,变量self._module_dict 就是上面提到的“查询表”,注册的模块都会存到这个字典类型的变量里,新建一个Registry实例就是新建一张查询表。另外,Registry还支持继承机制。
from mmcv.utils import Registry
class Registry:
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
??模块的注册通过Registry的成员函数register_module() 来实现,register_module() 内部又会调用另一个私有函数_register_module() ,模块注册的核心功能其实是在_register_module() 中实现的。核心代码也很简单,就是将传入的module_name 和module_class 保存到字典self._module_dict 中。
def _register_module(self, module_class, module_name=None, force=False):
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered in {self.name}')
self._module_dict[name] = module_class
??在我们通过字符串获取到一个模块的句柄后,可以通过self.build_func 函数句柄来实例化这个模块。build_func 可以人为指定,也可以从父类继承,一般来说都是默认使用build_from_cfg() 函数,即使用配置参数cfg 来初始化该模块。配置参数cfg 是一个字典,里面的type 字段是模块名称的字符串,其他字段则对应模块构造函数的输入参数。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
raise type(e)(f'{obj_cls.__name__}: {e}')
??考虑到registry 参数需要指向当前注册器本身,我们一般是调用Registry类的build() 方法而不是self.build_func 。
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
??下面是一个小例子,模拟了网络模型的注册和调用过程。注意一下,我们打印Registry对象时,实际上打印的是self._module_dict 中的values。
MODELS = Registry('myModels')
@MODELS.register_module()
class ResNet(object):
def __init__(self, depth):
self.depth = depth
print('Initialize ResNet{}'.format(depth))
class FPN(object):
def __init__(self, in_channel):
self.in_channel= in_channel
print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)
print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""
4. 算法实现流程
??2.2节提到,使用MMDetection实现一个新算法,包含注册数据集、注册模型、构建配置文件、训练/验证这四个步骤。要理解MMDetection的算法实现流程,必须要吃透Config、Registry、Runner和Hook这四个类。
4.1 注册数据集
??定义自己的数据集时,需要新写一个继承CustomDataset 的Dataset类,然后重写load_annotations() 函数和get_ann_info() 函数。官方文档上说,用户如果要使用CustomDataset ,要将现有数据集转换成MMDetection兼容的格式(COCO格式或中间格式) 。但我看了一下底层的代码并没有发现有这个限制,只要你的数据格式能和你实现的load_annotations() 和get_ann_info() 对应上即可。
"""
中间数据格式:
[
{
'filename': 'a.jpg', # 图片路径
'width': 1280, # 图片尺寸
'height': 720,
'ann': { # 标注信息
'bboxes': <np.ndarray, float32> (n, 4), # 标注框坐标(x1, y1, x2, y2)
'labels': <np.ndarray, int64> (n, ), # 标注框类别
'bboxes_ignore': <np.ndarray, float32> (k, 4), # 不关注的标注框坐标(可选)
'labels_ignore': <np.ndarray, int64> (k, ) # 不关注的标注框类别(可选)
}
},
...
]
"""
class CustomDataset(Dataset):
CLASSES = None
def __init__(self,
ann_file,
pipeline,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)
self.data_infos = self.load_annotations(self.ann_file)
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
self.pipeline = Compose(pipeline)
??在Pytorch中Dataset 的遍历是通过重写__getitem__() 函数实现的,但MMDetection的CustomDataset 虽然是Dataset 的子类,却没有要求我们重写__getitem__() 函数,原因是为了方便训练模式和测试模式下的数据管理,MMDetection已经重写了__getitem__() 函数,可以根据当前运行模式调用prepare_train_img() 或prepare_test_img() ,两者的区别在于是否加载训练标签。所以我们只需要重写load_annotations() 和get_ann_info() 函数,剩下的部分交给MMDetection就可以了。
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_img(idx)
else:
return self.prepare_train_img(idx)
def prepare_train_img(self, idx):
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
return self.pipeline(results)
def prepare_test_img(self, idx):
img_info = self.data_infos[idx]
results = dict(img_info=img_info)
return self.pipeline(results)
??完成自定义的Dataset类后别忘记加上@DATASETS.register_module() 将当前模块注册到DATASETS表中。
4.2 注册模型
??网络模型的定义比较简单,相比Pytorch只有两个区别:
- 继承的父类从
Module 变成了BaseModule - 需要按照backbone、neck和head的结构将模型拆解成3个部分,分别定义并注册到
BACKBONES 、NECKS 以及HEADS 当中。
4.3 构建配置文件
??2.2节有提到,在MMDetection框架下,我们不需要另外实现迭代训练/测试流程的代码,只需要执行现成的train.py或test.py即可。但MMDetection怎么知道我们需要哪些模块呢?这就是配置文件起到的作用。
4.3.1 配置文件的构成
??配置文件是由一系列变量定义组成的文本文件,其中dict 类型的变量表示一个个的模块,dict 变量必须包含type 字段,表示模块名称,其它字段则和模块构造函数的参数一一对应,届时用于该模块的初始化(见第本文3章的build_from_cfg() 函数)。该模块必须是已经注册的,否则后续MMDetection无法根据type 值找到对应的模块。配置文件除了dict 类型的变量以外,还可以是其他任意类型,一般是辅助dict 变量定义的中间变量,比如:
test_pipeline = [
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='PadMultiViewImage', size_divisor=32)
]
evaluation = dict(interval=2, pipeline=test_pipeline)
??配置文件也支持继承操作,通过_base_ 变量来实现。_base_ 是一个list 类型变量,里面存储的是要继承的配置文件的路径。在解析配置文件的时候,文件解析器以递归的方式(其他配置文件也可能包含_base_ 变量)解析所有配置文件。任何配置文件往上追溯都会继承以下四个文件,分别对应数据集(datasets)、模型(models)、训练策略(schedules)和运行时的默认配置(default_runtime):
_base_ = [
'mmdetection/configs/_base_/models/fast_rcnn_r50_fpn.py',
'mmdetection/configs/_base_/datasets/coco_detection.py',
'mmdetection/configs/_base_/schedules/schedule_1x.py',
'mmdetection/configs/_base_/default_runtime.py',
]
??如果你对上面继承这4个基础配置文件的配置文件进行打印,可以看到如下内容,这也是任何一个完整配置文件都应该包含的配置信息。当然,你也可以任意增加自定义的配置信息。所以我们平常新建一个配置文件的时候,一般都是继承这4个基础配置文件,然后在此基础上进行针对性调整。
model = dict(
type='FastRCNN',
backbone=dict(
type='ResNet',
...,
),
neck=dict(
type='FPN',
...,
),
roi_head=dict(
type='StandardRoIHead',
...,
loss_cls=dict(...),
loss_bbox=dict(...),
),
train_cfg=dict(
assigner=dict(...),
sampler=dict(...),
...
),
test_cfg =dict(
nms=dict(...),
...,
)
)
dataset_type = '...'
data_root = '...'
img_norm_cfg = dict(...)
train_pipeline = [
...,
]
test_pipeline = [...]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipline=trian_pipline,
),
val=dict(
...,
pipline=test_pipline,
),
test=dict(
...,
pipline=test_pipline,
)
)
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
??另外还有一些可选的配置参数,比如custom_imports ,用于导入用户自定义的模块,当配置文件解析器解析到该字段时,会调用import_modules_from_strings() 函数将字段imports 包含的模块导入到程序中。
custom_imports = dict(imports=['os.path', 'numpy'],
allow_failed_imports=False)
4.3.2 配置文件的修改
??修改配置文件时会遇到2种情况:
- 修改已有dict的某个参数:直接重写对应的参数
- 需要删掉原有dict的所有参数,然后用一组全新的参数代替:增加
_delete_=True 字段
??以修改学习率和更换优化器为例解释这两种情况下应该怎么修改配置文件:
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer = dict(lr=0.001)
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
4.3.3 配置文件的解析
??解析配置文件其实是train.py和test.py要做的事,这里放到和构建配置文件一块讲了,逻辑上会更通畅一些。
??一般使用Config类来管理配置文件。使用Config.fromfile(filename) 来读取配置文件(也可以直接传入一个dict),返回一个Config类实例cfg,然后可以通过print(cfg.pretty_text) 的方式来打印配置文件信息,或者通过cfg.dump(filepath) 来保存配置文件信息。
from mmcv import Config
cfg = Config.fromfile('../configs/test_config.py')
??fromfile() 函数源码如下,其核心函数是_file2dict() 。_file2dict() 会根据文本顺序,按照key = value的格式解析配置文件,得到一个名为cfg_dict 的字典,如果存在_base_ 字段,还会对_base_ 包含的每个文件路径再调用一次_file2dict() 函数,将文件中包含的配置参数加入到cfg_dict 中,实现配置文件的继承功能。需要注意的是,_file2dict() 内部会对_base_ 中不同文件包含的键值进行校验,不同基础配置文件中不允许出现重复的键值,否则Config不知道以哪个配置文件为准。
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
??调用_file2dict() 解析得到的cfg_dict 格式如下,配置文件中的文本信息全部转换成了变量存储在一个字典类型之中。
??另外有两点需要补充一下,其一是构造Config对象的时候,会将python的dict 数据类型转换为ConfigDict 类型进行处理。ConfigDict 是第三方库addict中Dict 的子类(Dict 又是pythondict 的子类),因为python原生的dict 类型不支持.属性 的访问方式,特别是dict 内部嵌套了多层dict的时候,如果按照key的访问方式,代码写起来非常低效,而Dict 类通过重写__getattr__() 的方式实现了.属性 的访问方式。所以继承了Dict 的ConfigDict 也支持使用.属性 的方式访问字典中的各个成员值。
from mmcv import ConfigDict
model = ConfigDict(dict(backbone=dict(type='ResNet', depth=50)))
print(model.backbone.type)
??其二,为了兼容配置文件名中出现小数点的情况,_file2dict() 会在C盘下创建一个临时文件夹进行操作,如果C盘有访问权限设置,可能会出现报错,不过这个问题只会出现在Windows系统下。
4.3.4 配置文件小结
??简单回顾一下,配置文件是一个包含多个dict 变量的文本文件,每个dict 对应一个具体的模块,dict 必须要有type 字段,其他字段和该模块的构造参数相对应。当对调用build() 函数对模块进行实例化的时候,会根据type 字符串的值从查询表中找到对应的模块句柄,并使用dict 中其他字段的值作为构造参数对该模块进行初始化。
4.4 训练和测试
(明天写)
|