Preface
This post mainly walks through the contents of detectron2's engine directory, which lays out how a model gets trained.
1. Creating Hooks
Starting with hooks may seem unfriendly, but hooks are the key to understanding detectron2's training flow. Below is the hook base class. There is not much to it; the point is that it defines four methods, i.e. the moments at which a hook takes effect: before training, after training, before each iteration, and after each iteration.
class HookBase:
    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass
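For a concrete feel of the interface, here is a minimal sketch of a custom hook (the class name and print statements are made up for illustration); it only overrides the stages it cares about, and the others fall back to the base class's pass:

from detectron2.engine import HookBase

class SimpleLoggingHook(HookBase):
    """Hypothetical hook: overrides only two of the four stages."""

    def before_train(self):
        print("training is about to start")

    def after_step(self):
        # self.trainer is attached when the hook is registered with a trainer
        # (see Section 2), so hooks can read trainer state such as the iteration.
        if self.trainer.iter % 100 == 0:
            print(f"finished iteration {self.trainer.iter}")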
Now look at the hooks.py file, which defines eight hook classes.
__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
]
These eight hooks fire at different points during training; for example, in the before_train stage, CallbackHook, IterationTimer, and others are invoked. Here are abridged versions of a few of them.
class CallbackHook(HookBase):
    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        self.step(self.trainer.iter)
The snippets above are taken from the source. Whichever of the four methods a hook implements determines the stages at which it gets called during training. For example, PeriodicCheckpointer above only implements before_train and after_step, so it is invoked only at those two stages; the remaining methods, such as before_step, simply fall through to the base class's pass. One more thing worth noting is the self.trainer attribute used inside hook methods; keep it in mind for now, as the next section shows where it comes from.
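To see self.trainer in action, here is a hedged usage sketch of the CallbackHook shown above: each callback is a plain function that receives the trainer object, so it can read attributes such as trainer.iter (the lambdas themselves are made-up examples):

from detectron2.engine.hooks import CallbackHook

# Each callback takes the trainer as its only argument.
log_hook = CallbackHook(
    before_train=lambda trainer: print("training starts"),
    after_step=lambda trainer: print("done with iter", trainer.iter),
)
# The hook would then be registered via trainer.register_hooks([log_hook]) (see Section 2).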
2. train_loop.py
2.1 The TrainerBase class
From Section 1 we know that each stage of training has its corresponding hooks. But hooks alone are not enough; we also need a class that runs them at the right times. That class is TrainerBase. Here is the source:
class TrainerBase:
    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.
        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError
From the above, the class mainly implements two things. One is register_hooks, which filters out None entries and extends an internal list with the given hooks; in the full source (the line is omitted in this excerpt) it also attaches the trainer to each hook via h.trainer = weakref.proxy(self), which is where the self.trainer seen in Section 1 comes from. The other is train, which lays out the execution order: before_train --> before_step --> run_step --> after_step --> after_train. Whenever a stage is reached, the trainer walks the hooks list and calls the corresponding method on each registered hook. The run_step method is deliberately left unimplemented so that subclasses can provide their own.
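Putting the pieces together, a tiny end-to-end sketch might look like this (NoOpTrainer and CountingHook are made up purely to show the call order; a real run_step would do actual work):

from detectron2.engine import HookBase, TrainerBase

class CountingHook(HookBase):
    """Hypothetical hook that just reports when each stage fires."""
    def before_train(self):
        print("before_train")
    def before_step(self):
        print("before_step, iter =", self.trainer.iter)
    def after_step(self):
        print("after_step, iter =", self.trainer.iter)
    def after_train(self):
        print("after_train")

class NoOpTrainer(TrainerBase):
    """Hypothetical trainer whose run_step does nothing, just to exercise the loop."""
    def run_step(self):
        pass

trainer = NoOpTrainer()
trainer.register_hooks([CountingHook(), None])  # None entries are filtered out
trainer.train(start_iter=0, max_iter=3)         # prints the hook calls in loop order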
2.2 The SimpleTrainer class
This class inherits from TrainerBase and sits one level higher. Its main job is to implement the run_step method mentioned above: forward pass, loss computation, backpropagation, and an optimizer update.
class SimpleTrainer(TrainerBase):
    def __init__(self, model, data_loader, optimizer):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a
                dict of losses.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
        """
        super().__init__()

        """
        We set the model to training mode in the trainer.
        However it's valid to train a model that's in eval mode.
        If you want your model (or a submodule of it) to behave
        like evaluation during training, you can overwrite its train() method.
        """
        model.train()

        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the losses, you can wrap the model.
        """
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())

        """
        If you need to accumulate gradients or do something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()
        losses.backward()

        self._write_metrics(loss_dict, data_time)

        """
        If you need gradient clipping/scaling or other processing, you can
        wrap the optimizer with your custom `step()` method. But it is
        suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
        """
        self.optimizer.step()
All in all, SimpleTrainer gives us the simplest possible trainer. In practice it is usually more flexible to inherit from TrainerBase and implement your own run_step, as in the sketch below. After that, the next section covers the more general trainer used throughout detectron2.
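As a rough illustration of that pattern, here is a hedged sketch of a TrainerBase subclass with its own run_step that accumulates gradients over a few mini-batches (the class and its accumulation logic are illustrative, not part of detectron2):

from detectron2.engine import TrainerBase

class GradAccumTrainer(TrainerBase):
    """Hypothetical trainer: like SimpleTrainer, but steps the optimizer
    only after accumulating gradients over `accum_steps` mini-batches."""

    def __init__(self, model, data_loader, optimizer, accum_steps=2):
        super().__init__()
        model.train()
        self.model = model
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer
        self.accum_steps = accum_steps

    def run_step(self):
        self.optimizer.zero_grad()
        for _ in range(self.accum_steps):
            data = next(self._data_loader_iter)
            loss_dict = self.model(data)                      # model returns a dict of losses
            losses = sum(loss_dict.values()) / self.accum_steps
            losses.backward()                                 # gradients accumulate across the loop
        self.optimizer.step()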
3. defaults.py
Most models are trained with the DefaultTrainer class defined in this file. Here is the source:
class DefaultTrainer(TrainerBase):
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
I have only pasted the initialization here, since the full class is long. In follow-up posts I will go through, one by one, how this class builds the model and optimizer and how its train method gets called.
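For reference, the usual way this class is driven from a training script looks roughly like the following (the dataset name and solver values are placeholders; the config file is one of the standard model-zoo configs):

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
cfg.DATASETS.TRAIN = ("my_dataset_train",)   # hypothetical registered dataset
cfg.SOLVER.MAX_ITER = 9000                   # example value
cfg.OUTPUT_DIR = "./output"

trainer = DefaultTrainer(cfg)         # builds model, optimizer, data loader, and hooks
trainer.resume_or_load(resume=False)  # load initial weights or resume from OUTPUT_DIR
trainer.train()                       # runs the TrainerBase loop with all registered hooks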
Summary
Finally, this series is off to a start.