IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> mmdet之Loss模块详解 -> 正文阅读

[人工智能]mmdet之Loss模块详解


前言

?该篇介绍mmdet的损失函数部分,后续会逐渐扩充mmdet中损失函数的使用注意事项以及使用方法。


1、mmdet中损失函数模块简介

1.1. Loss的注册器

?先来看段代码:mmdet/models/builder.py

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS) # 此处多了一个parent参数,暂时不予考虑

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS         # Loss 注册器
DETECTORS = MODELS

这里MODELS注册器同时赋予给了其他模块,为何操作后续会在

1.2. 注册L1 Loss()

@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): 预测框. 比如[N];
            target (torch.Tensor): 真实值.比如[N];
            weight (torch.Tensor, optional): 每个样本的权重,shape = [N], Defaults to None.
            avg_factor (int, optional): 控制总损失的系数,作用跟loss_weight重了。Defaults to None.
            reduction_override (str, optional): 作用跟reduction重了. Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

?上述初始化参数比较简单,就两个参数:reduction默认是’mean’,即返回损失的均值,loss_weight控制L1 Loss总的权重值。但在forward部分参数就多了:pred和target不必多说,二者shape应该一致,假设在处理bbox二者shape为[1000,4];weight的shape应该和pred的shape一样,控制每个样本对总的损失的权重值;avg_factor和reduction_override用的不多,这两个参数分别和loss_weight和reduction参数撞了,不用管。
?理解了上述各个参数作用,举个实际例子算一下:

import torch
from mmdet.models import build_loss

loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)

# 模块计算
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = obj(pred, target)
print(loss, 9/8)

?发现跟实际手算结果一致,简单说下计算流程:通过torch.abs计算每个元素之间的绝对值,然后.mean()方法得到最终的结果,这里除以的是所有元素的个数。比如此处就是2*4=8。
?在举个带weight的版本的:

import torch
from mmdet.models import build_loss

loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)

# 模块计算
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
# 带weight版本的: 最后一个元素的weight =0
weight = torch.Tensor([[1,1,1,1],[1,1,1,0]])     # [2,4]
loss = obj(pred, target, weight)
print(loss, 8/8)

1.3. 内部实现逻辑

?本质上使用的装饰器实现loss的封装,简单说下调用的流程:
1)调用forward方法,内部调用了 l1_loss函数;

@weighted_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()
    loss = torch.abs(pred - target)  # 对应元素相减
    return loss

2)此时碰见 @weighted_loss装饰器,则先跳入装饰器, 注意此时首先不计算l1 loss函数, mmdet/losses/losses/utils.py

def weighted_loss(loss_func):
    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # 获取每个元素之间损失
        loss = loss_func(pred, target, **kwargs) 
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper

?首先对loss_func即l1_loss进行了一次包装,即往里面多塞了一些参数**kwargs,然后此时执行l1_loss,得到各个元素之间的loss值。
3)最后一步,执行weight_reduce_loss来得到损失的最终形式(weight, reduction, avg_factor):

def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):

    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss

1.4. 总结

?基本上mmdet所有损失的计算流程就上述过程,在使用L1 Loss时,不必关心那么多超参,直接build loss然后传入pred和target即可,其余参数基本默认即可。

总结

?未完待续…

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-29 19:04:19  更:2022-06-29 19:07:09 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 8:46:30-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计