构建模型
MMDetection 通过读取 configs 目录下的配置文件来创建各个模块。
def parse_args():
    """Placeholder for the training script's command-line parser.

    The real argument parsing is not shown in this excerpt. As written,
    this stub has no body beyond its docstring and therefore yields None.
    """
    return None
def main():
    """Train a detector as described by a config file.

    The visible steps are:
    1. Parse CLI args.
    2. Load the config and apply CLI overrides.
    3. Pick the work dir and GPUs, and set up distributed mode if requested.
    4. Log environment info and set the random seed.
    5. Build the model and the dataset(s) from the config.
    6. Hand everything to ``train_detector``.
    """
    args = parse_args()
    # Load the .py config named on the command line.
    cfg = Config.fromfile(args.config)
    # Key/value overrides from the command line are merged on top of the file.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # cuDNN benchmark mode is enabled only if the config asks for it.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # work_dir priority: CLI argument > value in config >
    # ./work_dirs/<config file name without extension>.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # Explicit GPU ids win; otherwise use ids 0..gpus-1 (a single GPU when
    # --gpus is not given).
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # In distributed mode the GPU ids are re-derived from the world size,
        # overriding whatever was set above.
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
    # Create the work dir and save a copy of the (merged) config into it.
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # The log file is named after the start time of this run.
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
    # `meta` collects run information (env info, config text, seed,
    # experiment name) that is passed to train_detector below.
    meta = dict()
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')
    # NOTE(review): init_random_seed presumably picks a random seed when
    # args.seed is None -- confirm in its definition.
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)
    # cfg.model is still a plain dict here; build_detector turns it into a
    # model object (presumably via the registry / build_from_cfg machinery).
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    datasets = [build_dataset(cfg.data.train)]
    # With a two-stage workflow (presumably train + val -- confirm in config),
    # a second dataset is built from the val split but uses the *train*
    # pipeline. The val config is deep-copied so cfg.data.val is not mutated.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    # Store version (+ short git hash) and class names as checkpoint metadata.
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # The model also carries the training set's class names.
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
cfg 字典中的参数是从 configs 目录下某个 .py 配置文件中读取的。该 .py 文件还会继承其他配置文件，其中包括模型、数据以及学习率(lr)训练策略等设置，这些都可以在配置文件中修改。可以参见这篇博客(其中详细讲解了配置文件):
cfg 字典包含以下内容：model、train_pipeline、test_pipeline、data、optimizer、lr_config、checkpoint_config 等，所以修改参数时要沿着配置文件逐层追溯进去。此时 model 仍然是 dict 形式，还没有组装成真正的模型。创建模型：最终会调用 build_from_cfg(定义在 /opt/conda/envs/open-mmlab2/lib/python3.7/site-packages/mmcv/utils/registry.py 中)。
def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from one config dict or a list of config dicts.

    This is a thin layer over ``build_from_cfg``. A single config is
    forwarded to it unchanged. A list is built element by element, and the
    resulting modules are wrapped, in list order, in a ``Sequential``.

    NOTE(review): the list form presumably exists so that a single config
    entry can describe several modules applied one after another (e.g.
    stacked necks). Confirm against the configs that use it.

    When building a detector, ``cfg`` is typically ``cfg.model``, i.e. a dict
    with backbone / neck / head / train and test settings.

    Args:
        cfg (dict | list[dict]): Config of the module(s) to build.
        registry (:obj:`Registry`): Registry the module type(s) belong to.
        default_args (dict, optional): Default constructor arguments passed
            to every module that gets built. Defaults to None.

    Returns:
        nn.Module: The built module, or a ``Sequential`` wrapping all built
        modules when ``cfg`` is a list.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    built_modules = [
        build_from_cfg(sub_cfg, registry, default_args) for sub_cfg in cfg
    ]
    return Sequential(*built_modules)
进一步追溯到 build_from_cfg 函数：
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    The class to instantiate comes from the "type" entry. A string is looked
    up in ``registry``; a class object is used as-is. All remaining keys,
    plus any keys missing from ``cfg`` that ``default_args`` supplies, are
    passed to the class constructor as keyword arguments.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type",
            unless ``default_args`` provides it. For example, "type" is
            "FasterRCNN" when building that detector, or a dataset class
            name when building data.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: Raised in these cases:
            - ``cfg`` is not a dict.
            - ``default_args`` is neither a dict nor None.
            - ``registry`` is not a ``Registry``.
            - "type" is neither a str nor a class.
        KeyError: Raised in these cases:
            - Neither ``cfg`` nor ``default_args`` has a "type" key.
            - The type name is not found in ``registry``.
        Exception: Whatever the constructor raises. It is re-raised as the
            same exception type with the class name prepended, chained to
            the original. If that type cannot be rebuilt from a single
            message, the original exception is re-raised unchanged.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Validate default_args before the "type" membership test below uses it.
    # Otherwise a non-dict value (e.g. an int) fails with an unrelated
    # "argument of type ... is not iterable" error.
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    # Work on a shallow copy so that popping "type" and filling defaults
    # never mutates the caller's top-level config.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Constructor errors such as "unexpected keyword argument" do not name
        # the class, so re-raise the same exception type with it prefixed.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception types (e.g. UnicodeDecodeError) need several
            # constructor arguments; don't let that mask the real error.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
obj_cls 是一个类，它的具体实现可以在 'mmdet.models.detectors.faster_rcnn.FasterRCNN' 中看到。进入这个类可以查看模型是如何构建的，这部分内容较多，本文写不下了(参见这个链接)。构建完成后得到的 model 就是 PyTorch 形式的模型(nn.Module)。
构建数据
数据集的构建也是利用build_from_cfg函数。(未完待续…)
构建训练器
未完待续…
|