IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> DETR3D模型源码导读 & MMDetection3D构建流程 -> 正文阅读

[人工智能]DETR3D模型源码导读 & MMDetection3D构建流程

配置文件:detr3d_res101_gridmask_cbgs.py

在这里插入图片描述

Figure1. DETR3D Architecture

本文主要是学习整理,结合DETR3D的模型结构MMDetection3D的模型构建方法,首先介绍model dict的模型参数设置,然后介绍逐个介绍DETR3D中的子结构,过程中简单讲解mmdetection3d的模型构建流程。


model dict设置模型结构

model部分:定义按照backbone,neck,head的顺序设置模型参数。

# 此处省略关键参数,实际以具体的配置文件为准
model = dict(
    type='Detr3D',
    use_grid_mask=True,
    # resnet提取0,1,2,3层的特征
    img_backbone=dict(),
    img_neck=dict(),
    # transformer head定义,本层的dict所指代的类负责对包含在内的 下一层dict实体 进行实例化
    pts_bbox_head=dict(
        type='Detr3DHead',
        # head中只有decoder
        transformer=dict(),
        # loss,bbox,position_embedding
        bbox_coder=dict(), 
        positional_encoding=dict(),
        loss_cls=dict(),
    train_cfg=dict()))
    )

MMDetection3D中dict初始化模型

MMDetection3D利用类之间的包含关系(head中包含transformer, transformer中包含decoder等)递归实例化每个组件, 在build_model后,通过registry这种注册机制,递归地实例化每个registry model。

具体如何初始化呢? 编者在第一次看源码时也遇到了问题,框架的抽象程度很高,但是逐步推进到底层源码,了解registry的注册、调用、初始化方式,可以清楚了解整个流程,这里以transformer与decoder为例:

  1. 在transformer中,这里作为父层级类,初始化下一级的子层级类decoder:
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 decoder=None,
                 **kwargs):
        super(Detr3DTransformer, self).__init__(**kwargs)
        # 初始化decoder
        self.decoder = build_transformer_layer_sequence(decoder)
  1. build_transfomer_layer_sequence:调用自统一的设置函数
def build_from_cfg(cfg, registry, default_args=None):
# obj_type:transformer
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # get registry for dataset
        # 查询并获得registry注册好的decoder类
        obj_cls = registry.get(obj_type)
   return obj_cls

总结来说:

  1. 初始化顺序:Detr3D->backbone->neck->head->head_transformer->head_transformer_decoder->last_component
  2. 初始化方式:从train.py的build_model开始,上一级通过调用build逐级调用build_from_cfg初始化各自的子结构,直到最底层的结构(BasModule)初始化完成,但是很多类都是继承自官方提供的各种结构(DETR3DHead继承自DETRHead),这种继承的子类通过super(childclass, self).__ init __(cfg)传入模型参数在父类中完成子结构初始化
  3. 实例化方式:forward,与pytorch框架下相同

Backbone \ Neck

    img_backbone=dict(
        type='ResNet',
        # resnet101
        depth=101,
        # bottom-up结构特征图的C0,1,2,3
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True)),
    img_neck=dict(
        type='FPN',
        # FPN的输入channel
        in_channels=[256, 512, 1024, 2048],
        # 最终的四个特征图都是256维
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True)

Head: Detr3DHead

head继承自mmdet3d提供的DetrHead

    pts_bbox_head=dict(
        type='Detr3DHead',
        num_query=900,
        num_classes=10,
        in_channels=256,
        sync_cls_avg_factor=True,
        with_box_refine=True,
        as_two_stage=False,
        # head中只有decoder
        transformer=dict(),
        # loss,bbox,position_embedding
        bbox_coder=dict(
            type='NMSFreeCoder',
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            pc_range=point_cloud_range,
            max_num=300,
            voxel_size=voxel_size,
            num_classes=10), 
        positional_encoding=dict(
            type='SinePositionalEncoding',
            num_feats=128,
            normalize=True,
            offset=-0.5),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0),
        loss_bbox=dict(type='L1Loss', loss_weight=0.25),
        loss_iou=dict(type='GIoULoss', loss_weight=0.0))

Detr3DTransformer

在这里插入图片描述

最底层的部分,完成了论文中的主要创新点部分:

  1. 通过传感器的转换矩阵完成queries预测reference points,并投影到feature map通过Bilinear Interpolation集合固定区域内的特征;
  2. 利用上一步的特征进行object queries refinement完成对queries的改善用于目标预测;
	transformer=dict(
	    type='Detr3DTransformer',
	    decoder=dict(
	        type='Detr3DTransformerDecoder',
	        num_layers=6,
	        return_intermediate=True,
	        # 设置单个decoder layer参数
	        transformerlayers=dict(
	            type='DetrTransformerDecoderLayer',
	            attn_cfgs=[
	                dict(
	                    type='MultiheadAttention',
	                    embed_dims=256,
	                    num_heads=8,
	                    dropout=0.1),
	                dict(
	                    type='Detr3DCrossAtten',
	                    pc_range=point_cloud_range,
	                    num_points=1,
	                    embed_dims=256)
	            ],
	            feedforward_channels=512,
	            ffn_dropout=0.1,
	            operation_order=('self_attn', 'norm', 'cross_attn', 'norm','ffn', 'norm'))))
  • DETR3DTransformer:

负责DETR3D的关键部分:reference points,特征抓取,queries refinement,objects cross attention

@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def forward(self,
                mlvl_feats,
                query_embed,
                reg_branches=None,
                **kwargs):
        """
         mlvl_feats (list(Tensor)): [bs, embed_dims, h, w].
         query_embed (Tensor): [num_query, c].
         mlvl_pos_embeds (list(Tensor)): [bs, embed_dims, h, w].
         reg_branches (obj:`nn.ModuleList`): Regression heads 
         with_box_refine
        """
        bs = mlvl_feats[0].size(0)
        # 256 -> 128, 128
        query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
        # -1为保持原样
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        # query_pos作为输入通过reg_branches回归参考点对应的2d position
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        # decoder
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out

Decoder

decoder部分关键在于如何完成论文中提出的object queries refinement,这里着重进行介绍:
在这里插入图片描述

Figure2. queries refinement
  1. Decoder Block

    ? ?每一个decoder block流程:预测上一层queries对应的reference points后对queries进行refinement后,进行self-attention,作为下一个block输入:self.dropout(output) + inp_residual + pos_feat,即输出=原始输入+双线性插值特征+query位置特征

    ? ?如何对提取后的多尺度特征进行处理呢?
    ? ?这里的提取的图像特征,从shape=(bs, c, num_query, num_cam, 1, len(num_feature_level))到shape=(bs, c, num_query),通过三个连续的sum(-1),将不同视角的相机特征,不同尺度的相机特征,进行求和,得到最终的图像特征,然后通过project将图像特征投影到与query同维度,最后直接求和作为下一个Decoder Block的输入。

output = output.sum(-1).sum(-1).sum(-1)
@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):

        query = query.permute(1, 0, 2)
        bs, num_query, _ = query.size()
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        
        # 双线性插值
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)

        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        output = output.sum(-1).sum(-1).sum(-1) # sum后缩减三个维度:shape:[bs, c, num_query]
        output = output.permute(2, 0, 1) # [num_query, bs, c]
        
        output = self.output_proj(output) # (num_query, bs, embed_dims),将reference3d的dim转换到256
        # output作为fetch的feature,与经过encoder后的query、原始query直接相加作为refinement query
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)

        return self.dropout(output) + inp_residual + pos_feat
  1. Decoder Block feature sampling
    这里介绍如何进行图像特征提取,这里对不同特征层的图像分别插值提取特征用来refine queries。
# 特征采样部分
# 特征采样部分, Input queries from different level. Each element has shape [bs, embed_dims, h, w] 也就是[4, bs, embed_dims, h, w]
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # lidar2img:3D坐标以lidar为中心,求出3D点到img的转换关系也就是求出lidar到img的转换关系
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    # N = 6,referrence_points:[bs, num_query, 3]
    lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # recompute top-left(x,y) and bottom-right(x)
    reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2]
    # reference_points [bs, num_query, 3]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    # num_cam = 6
    num_cam = lidar2img.size(1)
    # from [b,1,num_query,4] to [b,num_cam,num_query, 4, 1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # shape:[b, num_cam, num_query, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)

    # project 3d -> 2d
    # shape:[b, num_cam, num_query, 4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    mask = (reference_points_cam[..., 2:3] > eps)
    # cam坐标归一化: reference_points_cam.shape:[b,num_cam,num_query,2]
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps)
    # 0,1分别代表camera像素坐标系下的x,y坐标,并进行归一化
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    reference_points_cam = (reference_points_cam - 0.5) * 2
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0) 
                 & (reference_points_cam[..., 0:1] < 1.0) 
                 & (reference_points_cam[..., 1:2] > -1.0) 
                 & (reference_points_cam[..., 1:2] < 1.0))
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    # 对四个特征层分别求出线性插值后的feature,其中N为num_query, [4, bs, embed_dims, h, w]
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size() # (num_key, bs, embed_dims)
        # N=num_cam
        feat = feat.view(B*N, C, H, W)
        # [b,num_cam,num_query,2] -> [b, num_cam, num_query, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
        # F.grid_sample return:[b*n,c,num_query,1]每个query对应着一个grid采样(bilinear incorparation)后返回的值
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # b,c,n_q,n,1
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    # [b,n,c,num_query,len(mlvl_feats)]
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam,  1, len(mlvl_feats)) 
    return reference_points_3d, sampled_feats, mask

关于F.grid_sample()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章           查看所有文章
加:2022-06-21 21:26:11  更:2022-06-21 21:28:35 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 3:35:34-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码