前言
?该篇介绍mmdet的损失函数部分,后续会逐渐扩充mmdet中损失函数的使用注意事项以及使用方法。
1、mmdet中损失函数模块简介
1.1. Loss的注册器
?先来看段代码:mmdet/models/builder.py
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
这里MODELS注册器同时赋予给了其他模块,为何操作后续会在
1.2. 注册L1 Loss()
@LOSSES.register_module()
class L1Loss(nn.Module):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(L1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): 预测框. 比如[N];
target (torch.Tensor): 真实值.比如[N];
weight (torch.Tensor, optional): 每个样本的权重,shape = [N], Defaults to None.
avg_factor (int, optional): 控制总损失的系数,作用跟loss_weight重了。Defaults to None.
reduction_override (str, optional): 作用跟reduction重了. Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
?上述初始化参数比较简单,就两个参数:reduction默认是’mean’,即返回损失的均值,loss_weight控制L1 Loss总的权重值。但在forward部分参数就多了:pred和target不必多说,二者shape应该一致,假设在处理bbox二者shape为[1000,4];weight的shape应该和pred的shape一样,控制每个样本对总的损失的权重值;avg_factor和reduction_override用的不多,这两个参数分别和loss_weight和reduction参数撞了,不用管。 ?理解了上述各个参数作用,举个实际例子算一下:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]])
loss = obj(pred, target)
print(loss, 9/8)
?发现跟实际手算结果一致,简单说下计算流程:通过torch.abs计算每个元素之间的绝对值,然后.mean()方法得到最终的结果,这里除以的是所有元素的个数。比如此处就是2*4=8。 ?在举个带weight的版本的:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]])
weight = torch.Tensor([[1,1,1,1],[1,1,1,0]])
loss = obj(pred, target, weight)
print(loss, 8/8)
1.3. 内部实现逻辑
?本质上使用的装饰器实现loss的封装,简单说下调用的流程: 1)调用forward方法,内部调用了 l1_loss函数;
@weighted_loss
def l1_loss(pred, target):
"""L1 loss.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
Returns:
torch.Tensor: Calculated loss
"""
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
loss = torch.abs(pred - target)
return loss
2)此时碰见 @weighted_loss装饰器,则先跳入装饰器, 注意此时首先不计算l1 loss函数, mmdet/losses/losses/utils.py
def weighted_loss(loss_func):
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
?首先对loss_func即l1_loss进行了一次包装,即往里面多塞了一些参数**kwargs,然后此时执行l1_loss,得到各个元素之间的loss值。 3)最后一步,执行weight_reduce_loss来得到损失的最终形式(weight, reduction, avg_factor):
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Return:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
if weight is not None:
loss = loss * weight
if avg_factor is None:
loss = reduce_loss(loss, reduction)
else:
if reduction == 'mean':
loss = loss.sum() / avg_factor
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
1.4. 总结
?基本上mmdet所有损失的计算流程就上述过程,在使用L1 Loss时,不必关心那么多超参,直接build loss然后传入pred和target即可,其余参数基本默认即可。
总结
?未完待续…
|