IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> MMEngine理解 -> 正文阅读

[人工智能]MMEngine理解

MMEngine理解

1 简介

  • MMEngine 是一个用于深度学习模型训练的基础库,基于 PyTorch,支持在 Linux、Windows、macOS 上运行。它具有如下三个亮点:
  • 通用:MMEngine 实现了一个高级的通用训练器,它能够:
    • 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 imagenet(pytorch example 400 行)
    • 轻松兼容流行的算法库如 TIMM、TorchVision 和 Detectron2 中的模型
  • 统一:MMEngine 设计了一个接口统一的开放架构,使得
    • 用户可以仅依赖一份代码实现所有任务的轻量化,例如 MMRazor 1.x 相比 MMRazor 0.x 优化了 40% 的代码量
    • 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
  • 灵活:MMEngine 实现了“乐高”式的训练流程,支持了
    • 根据迭代数、 loss 和评测结果等动态调整的训练流程、优化策略和数据增强策略,例如早停(early stopping)机制等
    • 任意形式的模型权重平均,如 Exponential Momentum Average (EMA) 和 Stochastic Weight Averaging (SWA)
    • 训练过程中针对任意数据和任意节点的灵活可视化和日志控制
    • 对神经网络模型中各个层的优化配置进行细粒度调整
    • 混合精度训练的灵活控制

1.1 架构

在这里插入图片描述

  • 上图展示了 MMEngine 在 OpenMMLab 2.0 中的层次。MMEngine 实现了 OpenMMLab 算法库的新一代训练架构,为 OpenMMLab 中的 30 多个算法库提供了统一的执行基座。其核心组件包含训练引擎、评测引擎和模块管理等。

1.2 模块介绍

在这里插入图片描述

  • MMEngine 将训练过程中涉及的组件和它们的关系进行了抽象,如上图所示。不同算法库中的同类型组件具有相同的接口定义。

1.2.1 核心模块与相关组件

  • 训练引擎的核心模块是执行器(Runner)。 执行器负责执行训练、测试和推理任务并管理这些过程中所需要的各个组件。在训练、测试、推理任务执行过程中的特定位置,执行器设置了钩子(Hook) 来允许用户拓展、插入和执行自定义逻辑。执行器主要调用如下组件来完成训练和推理过程中的循环:
    • 数据集(Dataset):负责在训练、测试、推理任务中构建数据集,并将数据送给模型。实际使用过程中会被数据加载器(DataLoader)封装一层,数据加载器会启动多个子进程来加载数据。
    • 模型(Model):在训练过程中接受数据并输出 loss;在测试、推理任务中接受数据,并进行预测。分布式训练等情况下会被模型的封装器(Model Wrapper,如MMDistributedDataParallel)封装一层。
    • 优化器封装(Optimizer):优化器封装负责在训练过程中执行反向传播优化模型,并且以统一的接口支持了混合精度训练和梯度累加。
    • 参数调度器(Parameter Scheduler):训练过程中,对学习率、动量等优化器超参数进行动态调整。
  • 在训练间隙或者测试阶段,评测指标与评测器(Metrics & Evaluator)会负责对模型性能进行评测。其中评测器负责基于数据集对模型的预测进行评估。评测器内还有一层抽象是评测指标,负责计算具体的一个或多个评测指标(如召回率、正确率等)。
  • 在训练、推理执行过程中,上述各个组件都可以调用日志管理模块和可视化器进行结构化和非结构化日志的存储与展示。日志管理(Logging Modules):负责管理执行器运行过程中产生的各种日志信息。其中消息枢纽 (MessageHub)负责实现组件与组件、执行器与执行器之间的数据共享,日志处理器(Log Processor)负责对日志信息进行处理,处理后的日志会分别发送给执行器的日志器(Logger)和可视化器(Visualizer)进行日志的管理与展示。可视化器(Visualizer):可视化器负责对模型的特征图、预测结果和训练过程中产生的结构化日志进行可视化,支持 Tensorboard 和 WanDB 等多种可视化后端。

1.2.1 公共基础模块

  • MMEngine 中还实现了各种算法模型执行过程中需要用到的公共基础模块,包括:
    • 配置类(Config):在 OpenMMLab 算法库中,用户可以通过编写 config 来配置训练、测试过程以及相关的组件。
    • 注册器(Registry):负责管理算法库中具有相同功能的模块。MMEngine 根据对算法库模块的抽象,定义了一套根注册器,算法库中的注册器可以继承自这套根注册器,实现模块的跨算法库调用。
    • 文件读写(File I/O):为各个模块的文件读写提供了统一的接口,以统一的形式支持了多种文件读写后端和多种文件格式,并具备扩展性。
    • 分布式通信原语(Distributed Communication Primitives):负责在程序分布式运行过程中不同进程间的通信。这套接口屏蔽了分布式和非分布式环境的区别,同时也自动处理了数据的设备和通信后端。
    • 其他工具(Utils):还有一些工具性的模块,如 ManagerMixin,它实现了一种全局变量的创建和获取方式,执行器内很多全局可见对象的基类就是 ManagerMixin。

2 上手示例

  • 以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、 可配置的训练和验证流程

2.1 构建模型

  • 首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 “loss”,并返回一个包含 “loss” 字段的字典;对于验证,我们需要 mode 接受字符串 “predict”,并返回同时包含预测信息和真实信息的结果。
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

2.2 构建数据集和数据加载器

  • 其次,我们需要构建训练和验证所需要的数据集 (Dataset)和数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

2.3 构建评测指标

  • 为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 process 和 compute_metrics 方法。其中 process 方法接受数据集的输出和模型 mode=“predict” 时的输出,此时的数据为一个批次的数据,对这一批次的数据进行处理后,保存信息至 self.results 属性。 而 compute_metrics 接受 results 参数,这一参数的输入为 process 中保存的所有信息 (如果是分布式环境,results 中为已收集的,包括各个进程 process 保存信息的结果),利用这些信息计算并返回保存有评测指标结果的字典。
from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的中间结果保存至 `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回保存有评测指标结果的字典,其中键为指标名称
        return dict(accuracy=100 * total_correct / total_size)

2.4 构建执行器并执行任务

  • 最后,我们利用构建好的模型,数据加载器,评测指标构建一个执行器 (Runner),同时在其中配置 优化器、工作路径、训练与验证配置等选项,即可通过调用 train() 接口启动训练:
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 用以训练和验证的模型,需要满足特定的接口需求
    model=MMResNet50(),
    # 工作路径,用以保存训练日志、权重文件信息
    work_dir='./work_dir',
    # 训练数据加载器,需要满足 PyTorch 数据加载器协议
    train_dataloader=train_dataloader,
    # 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 训练配置,用于指定训练周期、验证间隔等信息
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # 验证数据加载器,需要满足 PyTorch 数据加载器协议
    val_dataloader=val_dataloader,
    # 验证配置,用于指定验证所需要的额外参数
    val_cfg=dict(),
    # 用于验证的评测器,这里使用默认评测器,并评测指标
    val_evaluator=dict(type=Accuracy),
)

runner.train()

3. 基础模块

3.1 注册器(Registry)

  • OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 ResNet 和 SEResNet 类,这些类有相似的功能和接口,都属于算法库中的模型组件。 为了管理这些功能相似的模块,MMEngine 实现了 注册器。 OpenMMLab 大多数算法库均使用注册器来管理它们的代码模块,包括 MMDetection, MMDetection3D,MMPose, MMClassification 和 MMEditing 等。

3.1.1 什么是注册器

  • MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合
    • 映射表:维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 “ResNet” 到 ResNet 类或函数的映射,使得用户可以通过 “ResNet” 找到 ResNet 类;
    • 模块构建方法:定义了如何根据字符串查找到对应的类或函数以及如何实例化这个类或者调用这个函数,例如,通过字符串 “bn” 找到 nn.BatchNorm2d 并实例化 BatchNorm2d 模块;又或者通过字符串 “build_batchnorm2d” 找到 build_batchnorm2d 函数并返回该函数的调用结果。
    • MMEngine 中的注册器默认使用 build_from_cfg 函数来查找并实例化字符串对应的类或者函数。
  • 一个注册器:管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 MODELS 可以被视作所有模型的抽象,管理了 ResNet, SEResNet 和 RegNetX 等分类网络的类以及 build_ResNet, build_SEResNet 和 build_RegNetX 等分类网络的构建函数。
  • 注册器的定义(部分代码)
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Args:
        name (str): Registry name.
        build_func (callable, optional): A function to construct instance
            from Registry. :func:`build_from_cfg` is used if neither ``parent``
            or ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Defaults to None.
        parent (:obj:`Registry`, optional): Parent registry. The class
            registered in children registry could be built from parent.
            Defaults to None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Defaults to None.

    Examples:
        >>> # define a registry
        >>> MODELS = Registry('models')
        >>> # registry the `ResNet` to `MODELS`
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> # build model from `MODELS`
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

        >>> # hierarchical registry
        >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
        >>> @DETECTORS.register_module()
        >>> class FasterRCNN:
        >>>     pass
        >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))

    More advanced usages can be found at
    https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None):
        from .build_functions import build_from_cfg
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            self._scope = self.infer_scope()

        # See https://mypy.readthedocs.io/en/stable/common_issues.html#
        # variables-vs-type-aliases for the use
        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        self.build_func: Callable
        if build_func is None:
            if self.parent is not None:
                self.build_func = self.parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

3.1.2 使用流程

  • 使用注册器管理代码库中的模块,需要以下三个步骤:
    • 创建注册器
    • 创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)
    • 将模块加入注册器中
  • 假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。

3.1.2.1 创建注册器

from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
ACTIVATION = Registry('activation', scope='mmengine')

3.1.2.2 定义要注册的模块(类或函数)

import torch.nn as nn

# 使用注册器管理模块
@ACTIVATION.register_module()
class Sigmoid(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print('call Sigmoid.forward')
        return x

@ACTIVATION.register_module()
class ReLU(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()

    def forward(self, x):
        print('call ReLU.forward')
        return x

@ACTIVATION.register_module()
class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print('call Softmax.forward')
        return x
  • 使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 ACTIVATION 中。通过 @ACTIVATION.register_module() 装饰所实现的模块,字符串和类或函数之间的映射就可以由 ACTIVATION 构建和维护,我们也可以通过 ACTIVATION.register_module(module=ReLU) 实现同样的功能。
  • 通过注册,我们就可以通过 ACTIVATION 建立字符串与类或函数之间的映射:
print(ACTIVATION.module_dict)
# {
#     'Sigmoid': __main__.Sigmoid,
#     'ReLU': __main__.ReLU,
#     'Softmax': __main__.Softmax
# }
  • 只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 custom_imports 字段动态导入该模块进而触发注册机制,详情见导入自定义 Python 模块。

3.1.2.3 通过配置激活模块

  • 模块成功注册后,我们可以通过配置文件使用这个激活模块。
import torch
input = torch.randn(2)

act_cfg = dict(type='Sigmoid')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call Sigmoid.forward
print(output)

#如果我们想使用 ReLU,仅需修改配置。
act_cfg = dict(type='ReLU', inplace=True)
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call ReLU.forward
print(output)

3.1.3 跨项目调用

  • MMEngine 的注册器支持层级注册,利用该功能可实现跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。
  • 为了方便跨库调用,MMEngine 提供了 20 个根注册器
    • RUNNERS: Runner 的注册器
    • RUNNER_CONSTRUCTORS: Runner 的构造器
    • LOOPS: 管理训练、验证以及测试流程,如 EpochBasedTrainLoop
    • HOOKS: 钩子,如 CheckpointHook, ParamSchedulerHook
    • DATASETS: 数据集
    • DATA_SAMPLERS: DataLoader 的 Sampler,用于采样数据
    • TRANSFORMS: 各种数据预处理,如 Resize, Reshape
    • MODELS: 模型的各种模块
    • MODEL_WRAPPERS: 模型的包装器,如 MMDistributedDataParallel,用于对分布式数据并行
    • WEIGHT_INITIALIZERS: 权重初始化的工具
    • OPTIMIZERS: 注册了 PyTorch 中所有的 Optimizer 以及自定义的 Optimizer
    • OPTIM_WRAPPER: 对 Optimizer 相关操作的封装,如 OptimWrapper,AmpOptimWrapper
    • OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper 的构造器
    • PARAM_SCHEDULERS: 各种参数调度器,如 MultiStepLR
    • METRICS: 用于计算模型精度的评估指标,如 Accuracy
    • EVALUATOR: 用于计算模型精度的一个或多个评估指标
    • TASK_UTILS: 任务强相关的一些组件,如 AnchorGenerator, BboxCoder
    • VISUALIZERS: 管理绘制模块,如 DetVisualizer 可在图片上绘制预测框
    • VISBACKENDS: 存储训练日志的后端,如 LocalVisBackend, TensorboardVisBackend
    • LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 LogProcessor,如有特殊需求可自定义 LogProcessor

3.2 配置(Config)

  • MMEngine 实现了抽象的配置类(Config),为用户提供统一的配置访问接口。配置类能够支持不同格式的配置文件,包括 python,json,yaml,用户可以根据需求选择自己偏好的格式。配置类提供了类似字典或者 Python 对象属性的访问接口,用户可以十分自然地进行配置字段的读取和修改。为了方便算法框架管理配置文件,配置类也实现了一些特性,例如配置文件的字段继承等。

3.2.1 配置文件读取

  • 配置类提供了统一的接口 Config.fromfile(),来读取和解析配置文件。
  • 合法的配置文件应该定义一系列键值对,这里举几个不同格式配置文件的例子。
  • Python 格式:
test_int = 1
test_list = [1, 2, 3]
test_dict = dict(key1='value1', key2=0.1)
  • Json 格式:
{
  "test_int": 1,
  "test_list": [1, 2, 3],
  "test_dict": {"key1": "value1", "key2": 0.1}
}
  • YAML 格式:
test_int: 1
test_list: [1, 2, 3]
test_dict:
  key1: "value1"
  key2: 0.1
  • 对于以上三种格式的文件,假设文件名分别为 config.py,config.json,config.yml,调用 Config.fromfile(‘config.xxx’) 接口加载这三个文件都会得到相同的结果,构造了包含 3 个字段的配置对象。我们以 config.py 为例,我们先将示例配置文件下载到本地:
from mmengine.config import Config
cfg = Config.fromfile('learn_read_config.py')
print(cfg)
  • 输出结果为:
Config (path: learn_read_config.py): {'test_int': 1, 'test_list': [1, 2, 3], 'test_dict': {'key1': 'value1', 'key2': 0.1}}

3.2.2 配置文件的使用

  • 通过读取配置文件来初始化配置对象后,就可以像使用普通字典或者 Python 类一样来使用这个变量了。 我们提供了两种访问接口,即类似字典的接口 cfg[‘key’] 或者类似 Python 对象属性的接口 cfg.key。这两种接口都支持读写。
print(cfg.test_int)
print(cfg.test_list)
print(cfg.test_dict)
cfg.test_int = 2

print(cfg['test_int'])
print(cfg['test_list'])
print(cfg['test_dict'])
cfg['test_list'][1] = 3
print(cfg['test_list'])
  • 输出结果为:
1
[1, 2, 3]
{'key1': 'value1', 'key2': 0.1}
2
[1, 2, 3]
{'key1': 'value1', 'key2': 0.1}
[1, 3, 3]
  • 注意,配置文件中定义的嵌套字段(即类似字典的字段),在 Config 中会将其转化为 ConfigDict 类,该类继承了 Python 内置字典类型的全部接口,同时也支持以对象属性的方式访问数据。
  • 在算法库中,可以将配置与注册器结合起来使用,达到通过配置文件来控制模块构造的目的。这里举一个在配置文件中定义优化器的例子。
  • 假设我们已经定义了一个优化器的注册器 OPTIMIZERS,包括了各种优化器。那么首先写一个 config_sgd.py:
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
  • 然后在算法库中可以通过如下代码构造优化器对象。
from mmengine import Config, optim
from mmengine.registry import OPTIMIZERS # 在mmengine root.py中定义
import torch.nn as nn

cfg = Config.fromfile('config_sgd.py')
model = nn.Conv2d(1, 1, 1)
cfg.optimizer.params = model.parameters()
optimizer = OPTIMIZERS.build(cfg.optimizer)
print(optimizer)
  • 输出结果为:
SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.1
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)

3.2.3 配置文件的继承

  • 有时候,两个不同的配置文件之间的差异很小,可能仅仅只改了一个字段,我们就需要将所有内容复制粘贴一次,而且在后续观察的时候,不容易定位到具体差异的字段。又有些情况下,多个配置文件可能都有相同的一批字段,我们不得不在这些配置文件中进行复制粘贴,给后续的修改和维护带来了不便。
  • 为了解决这些问题,我们给配置文件增加了继承的机制,即一个配置文件 A 可以将另一个配置文件 B 作为自己的基础,直接继承了 B 中所有字段,而不必显式复制粘贴。
  • 这里我们举一个例子来说明继承机制。定义如下两个配置文件
  • optimizer_cfg.py:
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
  • resnet50.py:
_base_ = ['optimizer_cfg.py']
model = dict(type='ResNet', depth=50)
  • 虽然我们在 resnet50.py 中没有定义 optimizer 字段,但由于我们写了 _base_ = [‘optimizer_cfg.py’],会使这个配置文件获得 optimizer_cfg.py 中的所有字段。
cfg = Config.fromfile('resnet50.py')
print(cfg.optimizer)
  • 输出结果为:
{'type': 'SGD', 'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}
  • 这里 _base_ 是配置文件的保留字段,指定了该配置文件的继承来源。支持继承多个文件,将同时获得这多个文件中的所有字段,但是要求继承的多个文件中没有相同名称的字段,否则会报错。
  • runtime_cfg.py:
gpu_ids = [0, 1]
  • resnet50_runtime.py:
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
  • 这时,读取配置文件 resnet50_runtime.py 会获得 3 个字段 model,optimizer,gpu_ids。
cfg = Config.fromfile('resnet50_runtime.py')
print(cfg.optimizer)
  • 输出结果为:
{'type': 'SGD', 'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}
  • 通过这种方式,我们可以将配置文件进行拆分,定义一些通用配置文件,在实际配置文件中继承各种通用配置文件,可以减少具体任务的配置流程。

3.2.4 修改继承字段

  • 有时候,我们继承一个配置文件之后,可能需要对其中个别字段进行修改,例如继承了 optimizer_cfg.py 之后,想将学习率从 0.02 修改为 0.01。
  • 这时候,只需要在新的配置文件中,重新定义一下需要修改的字段即可。注意由于 optimizer 这个字段是一个字典,我们只需要重新定义这个字典里面需修改的下级字段即可。这个规则也适用于增加一些下级字段。
  • resnet50_lr0.01.py:
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
optimizer = dict(lr=0.01)
  • 读取这个配置文件之后,就可以得到期望的结果。
cfg = Config.fromfile('resnet50_lr0.01.py')
print(cfg.optimizer)
  • 输出结果为:
{'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0001}
  • 对于非字典类型的字段,例如整数,字符串,列表等,重新定义即可完全覆盖,例如下面的写法就将 gpu_ids 这个字段的值修改成了 [0]。
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
gpu_ids = [0]

3.2.5 删除字典中的 key

  • 有时候我们对于继承过来的字典类型字段,不仅仅是想修改其中某些 key,可能还需要删除其中的一些 key。这时候在重新定义这个字典时,需要指定 delete=True,表示将没有在新定义的字典中出现的 key 全部删除。
  • resnet50_delete_key.py:
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
optimizer = dict(_delete_=True, type='SGD', lr=0.01)
  • 这时候,optimizer 这个字典中就只有 type 和 lr 这两个 key,momentum 和 weight_decay 将不再被继承。
cfg = Config.fromfile('resnet50_delete_key.py')
print(cfg.optimizer)
  • 输出结果为:
{'type': 'SGD', 'lr': 0.01}

3.3 执行器(Runner)

  • 深度学习算法的训练、验证和测试通常都拥有相似的流程,因此 MMEngine 提供了执行器以帮助用户简化这些任务的实现流程。 用户只需要准备好模型训练、验证、测试所需要的模块构建执行器,便能够通过简单调用执行器的接口来完成这些任务。用户如果需要使用这几项功能中的某一项,只需要准备好对应功能所依赖的模块即可。
  • 构建模块的方式
    • 手动构建这些模块的实例
    • 通过编写配置文件,由执行器自动从注册器中构建所需要的模块
  • Runner接口
@RUNNERS.register_module()
class Runner:
    """A training helper for PyTorch.

    Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
    where the ``cfg`` usually contains training, validation, and test-related
    configurations to build corresponding components. We usually use the
    same config to launch training, testing, and validation tasks. However,
    only some of these components are necessary at the same time, e.g.,
    testing a model does not need training or validation-related components.

    To avoid repeatedly modifying config, the construction of ``Runner`` adopts
    lazy initialization to only initialize components when they are going to be
    used. Therefore, the model is always initialized at the beginning, and
    training, validation, and, testing related components are only initialized
    when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
    respectively.

    Args:
        model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
            a dict used for build a model.
        work_dir (str): The working directory to save checkpoints. The logs
            will be saved in the subdirectory of `work_dir` named
            :attr:`timestamp`.
        train_dataloader (Dataloader or dict, optional): A dataloader object or
            a dict to build a dataloader. If ``None`` is given, it means
            skipping training steps. Defaults to None.
            See :meth:`build_dataloader` for more details.
        val_dataloader (Dataloader or dict, optional): A dataloader object or
            a dict to build a dataloader. If ``None`` is given, it means
            skipping validation steps. Defaults to None.
            See :meth:`build_dataloader` for more details.
        test_dataloader (Dataloader or dict, optional): A dataloader object or
            a dict to build a dataloader. If ``None`` is given, it means
            skipping test steps. Defaults to None.
            See :meth:`build_dataloader` for more details.
        train_cfg (dict, optional): A dict to build a training loop. If it does
            not provide "type" key, it should contain "by_epoch" to decide
            which type of training loop :class:`EpochBasedTrainLoop` or
            :class:`IterBasedTrainLoop` should be used. If ``train_cfg``
            specified, :attr:`train_dataloader` should also be specified.
            Defaults to None. See :meth:`build_train_loop` for more details.
        val_cfg (dict, optional): A dict to build a validation loop. If it does
            not provide "type" key, :class:`ValLoop` will be used by default.
            If ``val_cfg`` specified, :attr:`val_dataloader` should also be
            specified. If ``ValLoop`` is built with `fp16=True``,
            ``runner.val()`` will be performed under fp16 precision.
            Defaults to None. See :meth:`build_val_loop` for more details.
        test_cfg (dict, optional): A dict to build a test loop. If it does
            not provide "type" key, :class:`TestLoop` will be used by default.
            If ``test_cfg`` specified, :attr:`test_dataloader` should also be
            specified. If ``ValLoop`` is built with `fp16=True``,
            ``runner.val()`` will be performed under fp16 precision.
            Defaults to None. See :meth:`build_test_loop` for more details.
        auto_scale_lr (dict, Optional): Config to scale the learning rate
            automatically. It includes ``base_batch_size`` and ``enable``.
            ``base_batch_size`` is the batch size that the optimizer lr is
            based on. ``enable`` is the switch to turn on and off the feature.
        optim_wrapper (OptimWrapper or dict, optional):
            Computing gradient of model parameters. If specified,
            :attr:`train_dataloader` should also be specified. If automatic
            mixed precision or gradient accmulation
            training is required. The type of ``optim_wrapper`` should be
            AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for
            examples. Defaults to None.
        param_scheduler (_ParamScheduler or dict or list, optional):
            Parameter scheduler for updating optimizer parameters. If
            specified, :attr:`optimizer` should also be specified.
            Defaults to None.
            See :meth:`build_param_scheduler` for examples.
        val_evaluator (Evaluator or dict or list, optional): A evaluator object
            used for computing metrics for validation. It can be a dict or a
            list of dict to build a evaluator. If specified,
            :attr:`val_dataloader` should also be specified. Defaults to None.
        test_evaluator (Evaluator or dict or list, optional): A evaluator
            object used for computing metrics for test steps. It can be a dict
            or a list of dict to build a evaluator. If specified,
            :attr:`test_dataloader` should also be specified. Defaults to None.
        default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to
            execute default actions like updating model parameters and saving
            checkpoints. Default hooks are ``OptimizerHook``,
            ``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and
            ``CheckpointHook``. Defaults to None.
            See :meth:`register_default_hooks` for more details.
        custom_hooks (list[dict] or list[Hook], optional): Hooks to execute
            custom actions like visualizing images processed by pipeline.
            Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict
            and doesn't contain the key ``data_preprocessor``, set the argument
            as the ``data_preprocessor`` of the ``model`` dict.
            Defaults to None.
        load_from (str, optional): The checkpoint file to load from.
            Defaults to None.
        resume (bool): Whether to resume training. Defaults to False. If
            ``resume`` is True and ``load_from`` is None, automatically to
            find latest checkpoint from ``work_dir``. If not found, resuming
            does nothing.
        launcher (str): Way to launcher multi-process. Supported launchers
            are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided,
            non-distributed environment will be launched.
        env_cfg (dict): A dict used for setting environment. Defaults to
            dict(dist_cfg=dict(backend='nccl')).
        log_processor (dict, optional): A processor to format logs. Defaults to
            None.
        log_level (int or str): The log level of MMLogger handlers.
            Defaults to 'INFO'.
        visualizer (Visualizer or dict, optional): A Visualizer object or a
            dict build Visualizer object. Defaults to None. If not
            specified, default config will be used.
        default_scope (str): Used to reset registries location.
            Defaults to "mmengine".
        randomness (dict): Some settings to make the experiment as reproducible
            as possible like seed and deterministic.
            Defaults to ``dict(seed=None)``. If seed is None, a random number
            will be generated and it will be broadcasted to all other processes
            if in distributed environment. If ``cudnn_benchmarch`` is
            ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
            ``randomness``, the value of ``torch.backends.cudnn.benchmark``
            will be ``False`` finally.
        experiment_name (str, optional): Name of current experiment. If not
            specified, timestamp will be used as ``experiment_name``.
            Defaults to None.
        cfg (dict or Configdict or :obj:`Config`, optional): Full config.
            Defaults to None.

    Examples:
        >>> from mmengine.runner import Runner
        >>> cfg = dict(
        >>>     model=dict(type='ToyModel'),
        >>>     work_dir='path/of/work_dir',
        >>>     train_dataloader=dict(
        >>>     dataset=dict(type='ToyDataset'),
        >>>     sampler=dict(type='DefaultSampler', shuffle=True),
        >>>     batch_size=1,
        >>>     num_workers=0),
        >>>     val_dataloader=dict(
        >>>         dataset=dict(type='ToyDataset'),
        >>>         sampler=dict(type='DefaultSampler', shuffle=False),
        >>>        batch_size=1,
        >>>        num_workers=0),
        >>>     test_dataloader=dict(
        >>>         dataset=dict(type='ToyDataset'),
        >>>         sampler=dict(type='DefaultSampler', shuffle=False),
        >>>         batch_size=1,
        >>>         num_workers=0),
        >>>     auto_scale_lr=dict(base_batch_size=16, enable=False),
        >>>     optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict(
        >>>         type='SGD', lr=0.01)),
        >>>     param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
        >>>     val_evaluator=dict(type='ToyEvaluator'),
        >>>     test_evaluator=dict(type='ToyEvaluator'),
        >>>     train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
        >>>     val_cfg=dict(),
        >>>     test_cfg=dict(),
        >>>     custom_hooks=[],
        >>>     default_hooks=dict(
        >>>         timer=dict(type='IterTimerHook'),
        >>>         checkpoint=dict(type='CheckpointHook', interval=1),
        >>>         logger=dict(type='LoggerHook'),
        >>>         optimizer=dict(type='OptimizerHook', grad_clip=False),
        >>>         param_scheduler=dict(type='ParamSchedulerHook')),
        >>>     launcher='none',
        >>>     env_cfg=dict(dist_cfg=dict(backend='nccl')),
        >>>     log_processor=dict(window_size=20),
        >>>     visualizer=dict(type='Visualizer',
        >>>     vis_backends=[dict(type='LocalVisBackend',
        >>>                        save_dir='temp_dir')])
        >>>    )
        >>> runner = Runner.from_cfg(cfg)
        >>> runner.train()
        >>> runner.test()
    """
    cfg: Config
    _train_loop: Optional[Union[BaseLoop, Dict]]
    _val_loop: Optional[Union[BaseLoop, Dict]]
    _test_loop: Optional[Union[BaseLoop, Dict]]

    def __init__(
        self,
        model: Union[nn.Module, Dict],
        work_dir: str,
        train_dataloader: Optional[Union[DataLoader, Dict]] = None,
        val_dataloader: Optional[Union[DataLoader, Dict]] = None,
        test_dataloader: Optional[Union[DataLoader, Dict]] = None,
        train_cfg: Optional[Dict] = None,
        val_cfg: Optional[Dict] = None,
        test_cfg: Optional[Dict] = None,
        auto_scale_lr: Optional[Dict] = None,
        optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
        param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
        val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
        data_preprocessor: Union[nn.Module, Dict, None] = None,
        load_from: Optional[str] = None,
        resume: bool = False,
        launcher: str = 'none',
        env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
        log_processor: Optional[Dict] = None,
        log_level: str = 'INFO',
        visualizer: Optional[Union[Visualizer, Dict]] = None,
        default_scope: str = 'mmengine',
        randomness: Dict = dict(seed=None),
        experiment_name: Optional[str] = None,
        cfg: Optional[ConfigType] = None,
    ):
        self._work_dir = osp.abspath(work_dir)
        mmengine.mkdir_or_exist(self._work_dir)

        # recursively copy the `cfg` because `self.cfg` will be modified
        # everywhere.
        if cfg is not None:
            if isinstance(cfg, Config):
                self.cfg = copy.deepcopy(cfg)
            elif isinstance(cfg, dict):
                self.cfg = Config(cfg)
        else:
            self.cfg = Config(dict())

        # lazy initialization
        training_related = [train_dataloader, train_cfg, optim_wrapper]
        if not (all(item is None for item in training_related)
                or all(item is not None for item in training_related)):
            raise ValueError(
                'train_dataloader, train_cfg, and optimizer should be either '
                'all None or not None, but got '
                f'train_dataloader={train_dataloader}, '
                f'train_cfg={train_cfg}, '
                f'optim_wrapper={optim_wrapper}.')
        self._train_dataloader = train_dataloader
        self._train_loop = train_cfg

        self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
        self.optim_wrapper = optim_wrapper

        self.auto_scale_lr = auto_scale_lr

        # If there is no need to adjust learning rate, momentum or other
        # parameters of optimizer, param_scheduler can be None
        if param_scheduler is not None and self.optim_wrapper is None:
            raise ValueError(
                'param_scheduler should be None when optimizer is None, '
                f'but got {param_scheduler}')

        # Parse `param_scheduler` to a list or a dict. If `optim_wrapper` is a
        # `dict` with single optimizer, parsed param_scheduler will be a
        # list of parameter schedulers. If `optim_wrapper` is
        # a `dict` with multiple optimizers, parsed `param_scheduler` will be
        # dict with multiple list of parameter schedulers.
        self._check_scheduler_cfg(param_scheduler)
        self.param_schedulers = param_scheduler

        val_related = [val_dataloader, val_cfg, val_evaluator]
        if not (all(item is None
                    for item in val_related) or all(item is not None
                                                    for item in val_related)):
            raise ValueError(
                'val_dataloader, val_cfg, and val_evaluator should be either '
                'all None or not None, but got '
                f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
                f'val_evaluator={val_evaluator}')
        self._val_dataloader = val_dataloader
        self._val_loop = val_cfg
        self._val_evaluator = val_evaluator

        test_related = [test_dataloader, test_cfg, test_evaluator]
        if not (all(item is None for item in test_related)
                or all(item is not None for item in test_related)):
            raise ValueError(
                'test_dataloader, test_cfg, and test_evaluator should be '
                'either all None or not None, but got '
                f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
                f'test_evaluator={test_evaluator}')
        self._test_dataloader = test_dataloader
        self._test_loop = test_cfg
        self._test_evaluator = test_evaluator

        self._launcher = launcher
        if self._launcher == 'none':
            self._distributed = False
        else:
            self._distributed = True

        # self._timestamp will be set in the `setup_env` method. Besides,
        # it also will initialize multi-process and (or) distributed
        # environment.
        self.setup_env(env_cfg)
        # self._deterministic and self._seed will be set in the
        # `set_randomness`` method
        self._randomness_cfg = randomness
        self.set_randomness(**randomness)

        if experiment_name is not None:
            self._experiment_name = f'{experiment_name}_{self._timestamp}'
        elif self.cfg.filename is not None:
            filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
            self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
        else:
            self._experiment_name = self.timestamp
        self._log_dir = osp.join(self.work_dir, self.timestamp)
        mmengine.mkdir_or_exist(self._log_dir)
        # Used to reset registries location. See :meth:`Registry.build` for
        # more details.
        self.default_scope = DefaultScope.get_instance(
            self._experiment_name, scope_name=default_scope)
        # Build log processor to format message.
        log_processor = dict() if log_processor is None else log_processor
        self.log_processor = self.build_log_processor(log_processor)
        # Since `get_instance` could return any subclass of ManagerMixin. The
        # corresponding attribute needs a type hint.
        self.logger = self.build_logger(log_level=log_level)

        # Collect and log environment information.
        self._log_env(env_cfg)

        # collect information of all modules registered in the registries
        registries_info = count_registered_modules(
            self.work_dir if self.rank == 0 else None, verbose=False)
        self.logger.debug(registries_info)

        # Build `message_hub` for communication among components.
        # `message_hub` can store log scalars (loss, learning rate) and
        # runtime information (iter and epoch). Those components that do not
        # have access to the runner can get iteration or epoch information
        # from `message_hub`. For example, models can get the latest created
        # `message_hub` by
        # `self.message_hub=MessageHub.get_current_instance()` and then get
        # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
        # See `MessageHub` and `ManagerMixin` for more details.
        self.message_hub = self.build_message_hub()
        # visualizer used for writing log or visualizing all kinds of data
        self.visualizer = self.build_visualizer(visualizer)
        if self.cfg:
            self.visualizer.add_config(self.cfg)

        self._load_from = load_from
        self._resume = resume
        # flag to mark whether checkpoint has been loaded or resumed
        self._has_loaded = False

        # build a model
        if isinstance(model, dict) and data_preprocessor is not None:
            # Merge the data_preprocessor to model config.
            model.setdefault('data_preprocessor', data_preprocessor)
        self.model = self.build_model(model)
        # wrap model
        self.model = self.wrap_model(
            self.cfg.get('model_wrapper_cfg'), self.model)

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._hooks: List[Hook] = []
        # register hooks to `self._hooks`
        self.register_hooks(default_hooks, custom_hooks)

        # dump `cfg` to `work_dir`
        self.dump_config()

3.3.1手动构建模块来使用执行器

3.3.1.1 手动构建模块进行训练

  • 使用执行器的某一项功能时需要准备好对应功能所依赖的模块。以使用执行器的训练功能为例,用户需要准备
    • 模型
    • 优化器
    • 参数调度器
    • 训练数据集
# 准备训练任务所需要的模块
import torch
from torch import nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from mmengine.model import BaseModel
from mmengine.optim.scheduler import MultiStepLR

# 定义一个多层感知机网络
class Network(BaseModel):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
        self.loss = nn.CrossEntropyLoss()

    def forward(self, batch_inputs: torch.Tensor, data_samples = None, mode: str = 'tensor'):
        x = batch_inputs.flatten(1)
        x = self.mlp(x)
        if mode == 'loss':
            return {'loss': self.loss(x, data_samples)}
        elif mode == 'predict':
            return x.argmax(1)
        else:
            return x

model = Network()

# 构建优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 构建参数调度器用于调整学习率
lr_scheduler = MultiStepLR(optimizer, milestones=[2], by_epoch=True)
# 构建手写数字识别 (MNIST) 数据集
train_dataset = datasets.MNIST(root="MNIST", download=True, train=True, transform=transforms.ToTensor())
# 构建数据加载器
train_dataloader = DataLoader(dataset=train_dataset, batch_size=10, num_workers=2)
  • 在创建完符合上述文档规范的模块的对象后,就可以使用这些模块初始化执行器:
from mmengine.runner import Runner

# 训练相关参数设置,按迭代次数训练,训练9000次迭代
#train_cfg = dict(by_epoch=False, max_iters=9000)

# 训练相关参数设置,按轮次训练,训练3轮
train_cfg = dict(by_epoch=True, max_epochs=3)
# 初始化执行器
runner = Runner(model,
                work_dir='./train_mnist',  # 工作目录,用于保存模型和日志
                train_cfg=train_cfg,
                train_dataloader=train_dataloader,
                optim_wrapper=dict(optimizer=optimizer),
                param_scheduler=lr_scheduler)
# 执行训练
runner.train()
  • 上面的例子中,我们手动构建了一个多层感知机网络和手写数字识别 (MNIST) 数据集,以及训练所需要的优化器和学习率调度器,使用这些模块初始化了执行器,并且设置了训练配置 train_cfg,让执行器将模型训练3个轮次,最后通过调用执行器的 train 方法进行模型训练。

3.3.1.2 手动构建模块进行测试

  • 模型的测试需要用户准备:
    • 模型
    • 训练好的权重路径
    • 测试数据集
    • 评测器
from mmengine.evaluator import BaseMetric


class MnistAccuracy(BaseMetric):
    def process(self, data, preds) -> None:
        self.results.append(((data[1] == preds.cpu()).sum(), len(preds)))
    def compute_metrics(self, results):
        correct, batch_size = zip(*results)
        acc = sum(correct) / sum(batch_size)
        return dict(accuracy=acc)

model = Network()
test_dataset = datasets.MNIST(root="MNIST", download=True, train=False, transform=transforms.ToTensor())
test_dataloader = DataLoader(dataset=test_dataset)
metric = MnistAccuracy()
test_evaluator = Evaluator(metric)

# 初始化执行器
runner = Runner(model=model, 
                test_dataloader=test_dataloader, 
                test_evaluator=test_evaluator,
                load_from='./train_mnist/epoch_3.pth', 
                work_dir='./test_mnist')

# 执行测试
runner.test()
  • 这个例子中重新手动构建了一个多层感知机网络,以及测试用的手写数字识别数据集和使用 (Accuracy) 指标的评测器,并使用这些模块初始化执行器,最后通过调用执行器的 test 函数进行模型测试。

3.3.1.3 手动构建模块在训练过程中进行验证

  • 在模型训练过程中,通常会按一定的间隔在验证集上对模型进行验证。在使用 MMEngine 时,只需要构建训练和验证的模块,并在训练配置中设置验证间隔即可
# 准备训练任务所需要的模块
optimzier = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
lr_scheduler = MultiStepLR(milestones=[2], by_epoch=True)
train_dataset = datasets.MNIST(root="MNIST", download=True, train=True, transform=transforms.ToTensor())
train_dataloader = DataLoader(dataset=train_dataset, batch_size=10, num_workers=2)

# 准备验证需要的模块
val_dataset = datasets.MNIST(root="MNIST", download=True, train=False, transform=transforms.ToTensor())
val_dataloader = Dataloader(dataset=val_dataset)
metric = MnistAccuracy()
val_evaluator = Evaluator(metric)


# 训练相关参数设置
train_cfg = dict(by_epoch=True,  # 按轮次训练
                 max_epochs=5,  # 训练5轮
                 val_begin=2,  # 从第 2 个 epoch 开始验证
                 val_interval=1)  # 每隔1轮进行1次验证

# 初始化执行器
runner = Runner(model=model, 
                optim_wrapper=dict(optimizer=optimzier), 
                param_scheduler=lr_scheduler,
                train_dataloader=train_dataloader, 
                val_dataloader=val_dataloader, 
                val_evaluator=val_evaluator,
                train_cfg=train_cfg, 
                work_dir='./train_val_mnist')
# 执行训练
runner.train()

3.3.2 通过配置文件使用执行器

  • OpenMMLab 的开源项目普遍使用注册器 + 配置文件的方式来管理和构建模块
  • MMEngine 中的执行器也推荐使用配置文件进行构建。 下面是一个通过配置文件使用执行器的例子:
from mmengine import Config
from mmengine.runner import Runner

# 加载配置文件
config = Config.fromfile('configs/resnet/resnet50_8xb32_in1k.py')

# 通过配置文件初始化执行器
runner = Runner.build_from_cfg(config)
# 执行训练
runner.train()
# 执行测试
runner.test()
  • 与手动构建模块来使用执行器不同的是,通过调用 Runner 类的 build_from_cfg 方法,执行器能够自动读取配置文件中的模块配置,从相应的注册器中构建所需要的模块,用户不再需要考虑训练和测试分别依赖哪些模块,也不需要为了切换训练的模型和数据而大量改动代码。
  • 下面是一个典型的使用配置文件调用 MMClassification 中的模块训练分类器的简单例子:
# 工作目录,保存权重和日志
work_dir = './train_resnet'
# 默认注册器域
default_scope = 'mmcls'  # 默认使用 `mmcls` (MMClassification) 注册器中的模块
# 模型配置
model = dict(type='ImageClassifier',
             backbone=dict(type='ResNet', depth=50),
             neck=dict(type='GlobalAveragePooling'),
             head=dict(type='LinearClsHead',num_classes=1000))
# 数据配置
train_dataloader = dict(dataset=dict(type='ImageNet', pipeline=[...]),
                        sampler=dict(type='DefaultSampler', shuffle=True),
                        batch_size=32,
                        num_workers=4)
val_dataloader = ...
test_dataloader = ...

# 优化器配置
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# 参数调度器配置
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
#验证和测试的评测器配置
val_evaluator = dict(type='Accuracy')
test_evaluator = dict(type='Accuracy')

# 训练、验证、测试流程配置
train_cfg = dict(
    by_epoch=True,
    max_epochs=100,
    val_begin=20,  # 从第 20 个 epoch 开始验证
    val_interval=1  # 每隔一个 epoch 进行一次验证
)
val_cfg = dict()
test_cfg = dict()

# 自定义钩子 (可选)
custom_hooks = [...]

# 默认钩子 (可选,未在配置文件中写明时将使用默认配置)
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),  # 运行时信息钩子
    timer=dict(type='IterTimerHook'),  # 计时器钩子
    sampler_seed=dict(type='DistSamplerSeedHook'),  # 为每轮次的数据采样设置随机种子的钩子
    logger=dict(type='TextLoggerHook'),  # 训练日志钩子
    param_scheduler=dict(type='ParamSchedulerHook'),  # 参数调度器执行钩子
    checkpoint=dict(type='CheckpointHook', interval=1),  # 模型保存钩子
)

# 环境配置 (可选,未在配置文件中写明时将使用默认配置)
env_cfg = dict(
    cudnn_benchmark=False,  # 是否使用 cudnn_benchmark
    dist_cfg=dict(backend='nccl'),  # 分布式通信后端
    mp_cfg=dict(mp_start_method='fork')  # 多进程设置
)
# 日志处理器 (可选,未在配置文件中写明时将使用默认配置)
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
# 日志等级配置
log_level = 'INFO'

# 加载权重的路径 (None 表示不加载)
load_from = None
# 从加载的权重文件中恢复训练
resume = False
  • 一个完整的配置文件主要由模型、数据、优化器、参数调度器、评测器等模块的配置,训练、验证、测试等流程的配置,还有执行流程过程中的各种钩子模块的配置,以及环境和日志等其他配置的字段组成。 通过配置文件构建的执行器采用了懒初始化 (lazy initialization),只有当调用到训练或测试等执行函数时,才会根据配置文件去完整初始化所需要的模块

3.3.3 加载权重或恢复训练

  • 执行器可以通过 load_from 参数加载检查点(checkpoint)文件中的模型权重,只需要将 load_from 参数设置为检查点文件的路径即可。
runner = Runner(model=model, 
                test_dataloader=test_dataloader, 
                test_evaluator=test_evaluator,
                load_from='./resnet50.pth')
  • 如果是通过配置文件使用执行器,只需修改配置文件中的 load_from 字段即可。
  • 用户也可通过设置 resume=True 来加载检查点中的训练状态信息来恢复训练。当 load_from 和 resume=True 同时被设置时,执行器将加载 load_from 路径对应的检查点文件中的训练状态。
  • 如果仅设置 resume=True,执行器将会尝试从 work_dir 文件夹中寻找并读取最新的检查点文件

3.4 钩子(Hook)

  • 钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。

3.4.1 内置钩子

  • MMEngine 提供了很多内置的钩子,将内置钩子分为两类,分别是默认钩子以及自定义钩子,前者表示会默认往执行器注册,后者表示需要用户自己注册。
  • 每个钩子都有对应的优先级,在同一位点,钩子的优先级越高,越早被执行器调用,如果优先级一样,被调用的顺序和钩子注册的顺序一致。优先级列表如下:
    • HIGHEST (0)
    • VERY_HIGH (10)
    • HIGH (30)
    • ABOVE_NORMAL (40)
    • NORMAL (50)
    • BELOW_NORMAL (60)
    • LOW (70)
    • VERY_LOW (90)
    • LOWEST (100)
  • 默认钩子
名称用途优先级
RuntimeInfoHook往 message hub 更新运行时信息VERY_HIGH (10)
IterTimerHook统计迭代耗时NORMAL (50)
DistSamplerSeedHook确保分布式 Sampler 的 shuffle 生效NORMAL (50)
LoggerHook打印日志BELOW_NORMAL (60)
ParamSchedulerHook调用 ParamScheduler 的 step 方法LOW (70)
CheckpointHook按指定间隔保存权重VERY_LOW (90)
  • 自定义钩子
名称用途优先级
EMAHook模型参数指数滑动平均NORMAL (50)
EmptyCacheHookPyTorch CUDA 缓存清理NORMAL (50)
SyncBuffersHook同步模型的 bufferNORMAL (50)
NaiveVisualizationHook可视化LOWEST (100)
  • 不建议修改默认钩子的优先级,因为优先级低的钩子可能会依赖优先级高的钩子。例如 CheckpointHook 的优先级需要比 ParamSchedulerHook 低,这样保存的优化器状态才是正确的状态。另外,自定义钩子的优先级默认为 NORMAL (50)。
  • 两种钩子在执行器中的设置不同,默认钩子的配置传给执行器的 default_hooks 参数,自定义钩子的配置传给 custom_hooks 参数,如下所示:
from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)

custom_hooks = [
    dict(type='NaiveVisualizationHook', priority='LOWEST'),
]

runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()

3.4.1.1 CheckpointHook

  • CheckpointHook 按照给定间隔保存模型的权重,如果是分布式多卡训练,则只有主(master)进程会保存权重。CheckpointHook 的主要功能如下:
    1. 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重
# by_epoch 的默认值为 True, 每隔 5 个 epoch 保存一次权重
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))

# 以迭代次数作为保存间隔,则可以将 by_epoch 设为 False,
# interval=5 则表示每迭代 5 次保存一次权重
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
    1. 保存最新的多个权重
# 假如一共训练 20 个 epoch,那么会在第 5, 10, 15, 20 个 epoch 保存模型,
# 但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候
# 会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))
    1. 保存最优权重
#  如果想要保存训练过程验证集的最优权重,可以设置 save_best 参数,如果设置为 'auto',
# 则会根据验证集的第一个评价指标(验证集返回的评价指标是一个有序字典)判断当前权重是否最优。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))
  • 保存最优权重参数规则:
    • 也可以直接指定 save_best 的值为评价指标,例如在分类任务中,可以指定为 save_best=‘top-1’,则会根据 ‘top-1’ 的值判断当前权重是否最优。
    • 除了 save_best 参数,和保存最优权重相关的参数还有 rule,greater_keys 和 less_keys,这三者用来判断 save_best 的值是越大越好还是越小越好。例如指定了 save_best=‘top-1’,可以指定 rule=‘greater’,则表示该值越大表示权重越好。
    1. 指定保存权重的路径
    • 权重默认保存在工作目录(work_dir),但可以通过设置 out_dir 改变保存路径。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))

3.4.1.2 LoggerHook

  • LoggerHook 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。
  • 如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

3.4.2 自定义钩子

  • 如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
  • 例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 after_train_iter 位点。
import torch
from mmengine.registry import HOOKS
from mmengine.hooks import Hook

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')
  • 我们只需将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子
from mmengine.runner import Runner

custom_hooks = dict(
    dict(type='CheckInvalidLossHook', interval=50)
)
runner = Runner(custom_hooks=custom_hooks, ...)  # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train()  # 执行器开始训练

3.5 模型(Model)

  • 在训练深度学习任务时,我们通常需要定义一个模型来实现算法的主体。在基于 MMEngine 开发时,模型由执行器管理,需要实现 train_step,val_step 和 test_step 方法。
  • 对于检测、识别、分割一类的深度学习任务,上述方法通常为标准的流程,例如在 train_step 里更新参数,返回损失;val_step 和 test_step 返回预测结果。因此 MMEngine 抽象出模型基类 BaseModel,实现了上述接口的标准流程。我们只需要让模型继承自模型基类,并按照一定的规范实现 forward,就能让模型在执行器中运行起来。
  • 模型基类继承自模块基类,能够通过配置 init_cfg 灵活的选择初始化方式。

3.5.1 接口约定

3.5.1.1 forward

  • forward: forward 的入参需要和 DataLoader 的输出保持一致 (自定义数据预处理器除外),如果 DataLoader 返回元组类型的数据 data,forward 需要能够接受 *data 的解包后的参数;如果返回字典类型的数据 data,forward 需要能够接受 **data 解包后的参数。 mode 参数用于控制 forward 的返回结果:
    • mode=‘loss’:loss 模式通常在训练阶段启用,并返回一个损失字典。损失字典的 key-value 分别为损失名和可微的 torch.Tensor。字典中记录的损失会被用于更新参数和记录日志。模型基类会在 train_step 方法中调用该模式的 forward。
    • mode=‘predict’: predict 模式通常在验证、测试阶段启用,并返回列表/元组形式的预测结果,预测结果需要和 process 接口的参数相匹配。OpenMMLab 系列算法对 predict 模式的输出有着更加严格的约定,需要输出列表形式的数据元素。模型基类会在 val_step,test_step 方法中调用该模式的 forward。
    • mode=‘tensor’:tensor 和 predict 模式均返回模型的前向推理结果,区别在于 tensor 模式下,forward 会返回未经后处理的张量,例如返回未经非极大值抑制(nms)处理的检测结果,返回未经 argmax 处理的分类结果。我们可以基于 tensor 模式的结果进行自定义的后处理。

3.5.1.2 train_step

  • train_step: 调用 loss 模式的 forward 接口,得到损失字典。模型基类基于优化器封装 实现了标准的梯度计算、参数更新、梯度清零流程。

3.5.1.3 val_step

  • val_step: 调用 predict 模式的 forward,返回预测结果,预测结果会被进一步传给评测器的 process 接口和钩子(Hook)的 after_val_iter 接口。

3.5.1.4 test_step

  • test_step: 同 val_step,预测结果会被进一步传给 after_test_iter 接口。

3.5.1.5 实例

  • 基于上述接口约定,我们定义了继承自模型基类的 NeuralNetwork,配合执行器来训练 FashionMNIST:
from torch.utils.data import DataLoader
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from mmengine.model import BaseModel
from mmengine.evaluator import BaseMetric
from mmengine.runner import Runner


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(dataset=training_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)


class NeuralNetwork(BaseModel):
    def __init__(self, data_preprocessor=None):
        super(NeuralNetwork, self).__init__(data_preprocessor)
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, img, label, mode='tensor'):
        x = self.flatten(img)
        pred = self.linear_relu_stack(x)
        loss = self.loss(pred, label)
        if mode == 'loss':
            return dict(loss=loss)
        elif mode=='predict':
            return pred.argmax(1), loss.item()
        else:
            return pred


class FashionMnistMetric(BaseMetric):
    def process(self, data, preds) -> None:
        # data 参数为 Dataloader 返回的元组,即 (img, label)
        # predict 为模型 `predict` 模式下,返回的元组,分别为 `pred.argmax(1) 和 `loss``
        self.results.append(((data[1] == preds[0].cpu()).sum(), preds[1], len(preds[0])))

    def compute_metrics(self, results):
        correct, loss, batch_size = zip(*results)
        test_loss, correct = sum(loss) / len(self.results), sum(correct) / sum(batch_size)
        return dict(Accuracy=correct, Avg_loss=test_loss)


runner = Runner(
    model=NeuralNetwork(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=1e-3)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_cfg=dict(fp16=True),
    val_dataloader=test_dataloader,
    val_evaluator=dict(metrics=FashionMnistMetric()))
runner.train()
  • 在本例中,NeuralNetwork.forward 存在着以下跨模块的接口约定:
    • 由于 train_dataloader 会返回一个 (img, label) 形式的元组,因此 forward 接口的前两个参数分别需要为 img 和 label。
    • 由于 forward 在 predict 模式下会返回 (pred, loss) 形式的元组,因此 process 的 preds 参数应当同样为相同形式的元组。

3.5.2 数据预处理器(DataPreprocessor)

  • 如果你的电脑配有 GPU(或其他能够加速训练的硬件,如 mps、ipu 等),并运行了上节的代码示例。你会发现 Pytorch 的示例是在 CPU 上运行的,而 MMEngine 的示例是在 GPU 上运行的。MMEngine 是在何时把数据和模型从 CPU 搬运到 GPU 的呢?
  • 事实上,执行器会在构造阶段将模型搬运到指定设备,而数据则会在 train_step、val_step、test_step 中,被基础数据预处理器(BaseDataPreprocessor)搬运到指定设备,进一步将处理好的数据传给模型。数据预处理器作为模型基类的一个属性,会在模型基类的构造过程中被实例化。
  • 为了体现数据预处理器起到的作用,我们仍然以上一节训练 FashionMNIST 为例, 实现了一个简易的数据预处理器,用于搬运数据和归一化:
from torch.optim import SGD
from mmengine.model import BaseDataPreprocessor, BaseModel


class NeuralNetwork1(NeuralNetwork):

    def __init__(self, data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)
        self.data_preprocessor = data_preprocessor

    def train_step(self, data, optimizer):
        img, label = self.data_preprocessor(data)
        loss = self(img, label, mode='loss')['loss'].sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return dict(loss=loss)

    def test_step(self, data):
        img, label = self.data_preprocessor(data)
        return self(img, label, mode='predict')

    def val_step(self, data):
        img, label = self.data_preprocessor(data)
        return self(img, label, mode='predict')


class NormalizeDataPreprocessor(BaseDataPreprocessor):

    def forward(self, data, training=False):
        img, label = [item for item in data]
        img = (img - 127.5) / 127.5
        return img, label


model = NeuralNetwork1(data_preprocessor=NormalizeDataPreprocessor())
optimizer = SGD(model.parameters(), lr=0.01)
data = (torch.full((3, 28, 28), fill_value=127.5), torch.ones(3, 10))

model.train_step(data, optimizer)
model.val_step(data)
model.test_step(data)
  • 上例中,我们实现了 BaseModel.train_step、BaseModel.val_step 和 BaseModel.test_step 的简化版。数据经 NormalizeDataPreprocessor.forward 归一化处理,解包后传给 NeuralNetwork.forward,进一步返回损失或者预测结果。如果想实现自定义的参数优化或预测逻辑,可以自行实现 train_step、val_step 和 test_step,具体例子可以参考:使用 MMEngine 训练生成对抗网络

3.6 模型精度评测(Evaluation)

  • 在模型验证和模型测试中,通常需要对模型精度做定量评测。在 MMEngine 中实现了评测指标(Metric)和评测器(Evaluator)模块来完成这一功能:
    • 评测指标: 用于根据测试数据和模型预测结果,完成模型特定精度指标的计算。在 OpenMMLab 各算法库中提供了对应任务的常用评测指标,如 MMClassification 中提供了分类正确率指标(Accuracy) 用于计算分类模型的 Top-k 分类正确率。
    • 评测器: 是评测指标的上层模块,用于在数据输入评测指标前完成必要的格式转换,并提供分布式支持。在模型训练和测试中,评测器由执行器(Runner)自动构建。用户亦可根据需求手动创建评测器,进行离线评测。

3.6.1 在模型训练或测试中进行评测

3.6.1.1 评测指标配置

  • 在基于 MMEngine 进行模型训练或测试时,执行器会自动构建评测器进行评测,用户只需要在配置文件中通过 val_evaluator 和 test_evaluator 2 个字段分别指定模型验证和测试阶段的评测指标即可。例如,用户在使用 MMClassification 训练分类模型时,希望在模型验证阶段评测 top-1 和 top-5 分类正确率,可以按以下方式配置:
val_evaluator = dict(type='Accuracy', top_k=(1, 5))  # 使用分类正确率评测指标
  • 如果需要同时评测多个指标,也可以将 val_evaluator 或 test_evaluator 设置为一个列表,其中每一项为一个评测指标的配置信息。例如,在使用 MMDetection 训练全景分割模型时,希望在模型测试阶段同时评测模型的目标检测(COCO AP/AR)和全景分割精度,可以按以下方式配置:
test_evaluator = [
    # 目标检测指标
    dict(
        type='COCOMetric',
        metric=['bbox', 'segm'],
        ann_file='annotations/instances_val2017.json',
    ),
    # 全景分割指标
    dict(
        type='CocoPanopticMetric',
        ann_file='annotations/panoptic_val2017.json',
        seg_prefix='annotations/panoptic_val2017',
    )
]

3.6.1.2 自定义评测指标

  • 如果算法库中提供的常用评测指标无法满足需求,用户也可以增加自定义的评测指标。具体的方法可以参考评测指标和评测器设计

3.6.2 使用离线结果进行评测

  • 另一种常见的模型评测方式,是利用提前保存在文件中的模型预测结果进行离线评测。此时,由于不存在执行器,用户需要手动构建评测器,并调用评测器的相应接口完成评测。以下是一个离线评测示例:
from mmengine.evaluator import Evaluator
from mmengine.fileio import load

# 构建评测器。参数 `metrics` 为评测指标配置
evaluator = Evaluator(metrics=dict(type='Accuracy', top_k=(1, 5)))

# 从文件中读取测试数据。数据格式需要参考具使用的 metric。
data = load('test_data.pkl')

# 从文件中读取模型预测结果。该结果由待评测算法在测试数据集上推理得到。
# 数据格式需要参考具使用的 metric。
predictions = load('prediction.pkl')

# 调用评测器离线评测接口,得到评测结果
# chunk_size 表示每次处理的样本数量,可根据内存大小调整
results = evaluator.offline_evaluate(data, predictions, chunk_size=128)

3.7 优化器封装(OptimWrapper)

  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。

3.7.1 优化器封装 vs 优化器

  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装进行单精度训练、混合精度训练和梯度累加,对比二者实现上的区别。

3.7.1.1 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

3.7.1.2 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。

3.7.1.3 基于 Pytorch 的 SGD 优化器实现混合精度训练

from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

3.7.1.4 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

3.7.1.5 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

3.7.1.6 基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。
  • 优化器封装同样提供了更细粒度的接口,方便用户实现一些自定义的参数更新逻辑:
    • backward:传入损失,用于计算参数梯度
    • step: 同 optimizer.step,用于更新参数
    • zero_grad: 同 optimizer.zero_grad,用于参数的梯度清0
  • 可以使用上述接口实现和 Pytorch 优化器相同的参数更新逻辑:
for idx, (input, target) in enumerate(zip(inputs, targets)):
    optimizer.zero_grad()
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.backward(loss)  # 计算参数梯度
    if idx % 2 == 0:
        optim_wrapper.step() # 更新参数
        optim_wrapper.zero_grad() # 参数的梯度清0
  • 为优化器封装配置梯度裁减策略:
# 基于 torch.nn.utils.clip_grad_norm_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer, clip_grad=dict(max_norm=1))

# 基于 torch.nn.utils.clip_grad_value_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer, clip_grad=dict(clip_value=0.2))

3.7.1.7 获取学习率/动量

  • 优化器封装提供了 get_lr 和 get_momentum 接口用于获取优化器的一个参数组的学习率
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer)

print(optimizer.param_groups[0]['lr'])  # -1.01
print(optimizer.param_groups[0]['momentum'])  # 0
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}
# 输出结果
#0.01
#0
#{'lr': [0.01]}
#{'momentum': [0]}

3.7.1.8 导出/加载状态字典

  • 优化器封装和优化器一样,提供了 state_dict 和 load_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

3.8 优化器参数调整策略(Parameter Scheduler)

3.9 数据变换 (Data Transform)

4. 高级模块

4.1 数据集基类(BaseDataset)

4.2 抽象数据接口

4.3 可视化

4.4 初始化

4.5 分布式通信原语

4.6 记录日志

4.7 文件读写

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-31 11:56:46  更:2022-10-31 11:59:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 2:18:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计