Config file: detr3d_res101_gridmask_cbgs.py
Figure 1. DETR3D Architecture
This post is a study write-up that walks through the DETR3D model structure together with MMDetection3D's model-building machinery: it first introduces the model parameters defined in the model dict, then goes through the sub-modules of DETR3D one by one, briefly explaining MMDetection3D's model construction flow along the way.
Setting up the model structure with the model dict
The model dict defines the model parameters in the order backbone, neck, head.
model = dict(
    type='Detr3D',
    use_grid_mask=True,
    img_backbone=dict(),
    img_neck=dict(),
    pts_bbox_head=dict(
        type='Detr3DHead',
        transformer=dict(),
        bbox_coder=dict(),
        positional_encoding=dict(),
        loss_cls=dict()),
    train_cfg=dict())
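As a minimal sketch of how this dict becomes an actual model (assuming the mmdet3d 0.17-style API used by the public DETR3D repo; the config path is the one from that repo, and the project plugin must be importable so that Detr3D is registered), tools/train.py essentially does:

# Minimal sketch, assuming mmdet3d's build_model API; in the DETR3D repo the
# project plugin (projects/mmdet3d_plugin) must be imported first so that
# 'Detr3D' and its sub-modules are registered.
from mmcv import Config
from mmdet3d.models import build_model

cfg = Config.fromfile('projects/configs/detr3d/detr3d_res101_gridmask_cbgs.py')
model = build_model(cfg.model,
                    train_cfg=cfg.get('train_cfg'),
                    test_cfg=cfg.get('test_cfg'))
print(type(model))  # Detr3D, built recursively from the nested dicts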
Initializing the model from the dict in MMDetection3D
MMDetection3D exploits the containment relationships between classes (the head contains a transformer, the transformer contains a decoder, and so on): after build_model is called, the registry mechanism recursively instantiates every registered component.
How exactly does the initialization work? The author also found this confusing on a first read of the source code, because the framework is highly abstract; but by drilling down to the low-level source and understanding how the registry registers, looks up, and instantiates classes, the whole flow becomes clear. Here we take the transformer and the decoder as an example:
- In the transformer, acting here as the parent-level class, the next-level child class, the decoder, is initialized:
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 decoder=None,
                 **kwargs):
        super(Detr3DTransformer, self).__init__(**kwargs)
        self.decoder = build_transformer_layer_sequence(decoder)
- build_transformer_layer_sequence: like the other build helpers, it dispatches to the unified build_from_cfg function (simplified below):
def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_type = args.pop('type')           # e.g. 'Detr3DTransformerDecoder'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # look the class up in the registry
    return obj_cls(**args)                # instantiate it with the remaining kwargs
To summarize:
- Initialization order: Detr3D -> backbone -> neck -> head -> head_transformer -> head_transformer_decoder -> last component
- Initialization mechanism: starting from build_model in train.py, each level calls a build helper, which in turn calls build_from_cfg to instantiate its own sub-structures, all the way down to the bottom-most structures (BaseModule). Many classes, however, inherit from structures provided by the official codebase (Detr3DHead inherits from DETRHead); such subclasses pass the model parameters through super(childclass, self).__init__(cfg), so their sub-structures are initialized inside the parent class. A toy sketch of this recursive build pattern follows this list.
- Invocation: via forward, exactly as in plain PyTorch.
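The following toy sketch illustrates the recursive build: each parent pops its own keys from the cfg dict and builds its children from the remaining nested dicts. Registry, build, ToyTransformer, and ToyDecoder below are simplified stand-ins for illustration, not the real mmcv/mmdet3d implementations.

# Toy illustration of the registry pattern; everything here is a simplified
# stand-in, not the real mmcv / mmdet3d code.
class Registry:
    def __init__(self):
        self._modules = {}
    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register
    def get(self, name):
        return self._modules[name]

TRANSFORMER = Registry()

def build(cfg, registry):
    args = dict(cfg)
    cls = registry.get(args.pop('type'))
    return cls(**args)

@TRANSFORMER.register_module()
class ToyDecoder:
    def __init__(self, num_layers=6):
        self.num_layers = num_layers

@TRANSFORMER.register_module()
class ToyTransformer:
    def __init__(self, decoder=None):
        # the parent builds its child from the nested dict, exactly like
        # Detr3DTransformer builds its decoder
        self.decoder = build(decoder, TRANSFORMER)

transformer = build(dict(type='ToyTransformer',
                         decoder=dict(type='ToyDecoder', num_layers=6)),
                    TRANSFORMER)
print(transformer.decoder.num_layers)  # 6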
Backbone / Neck
img_backbone=dict(
    type='ResNet',
    depth=101,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN2d', requires_grad=False),
    norm_eval=True,
    style='caffe',
    dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
    stage_with_dcn=(False, False, True, True)),
img_neck=dict(
    type='FPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    start_level=1,
    add_extra_convs='on_output',
    num_outs=4,
    relu_before_extra_convs=True)
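To get a feel for the feature shapes this pair produces, here is a minimal sketch that builds a simplified variant (same channel layout, but without the DCN stages and caffe-style weights, and with an assumed padded nuScenes input of 928x1600) and runs a dummy 6-camera batch through it:

# Minimal sketch, assuming mmdet's build_backbone / build_neck; DCN is dropped
# here only to keep the example dependency-free.
import torch
from mmdet.models import build_backbone, build_neck

backbone = build_backbone(dict(
    type='ResNet', depth=101, num_stages=4, out_indices=(0, 1, 2, 3),
    frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True, style='pytorch'))
neck = build_neck(dict(
    type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256,
    start_level=1, add_extra_convs='on_output', num_outs=4,
    relu_before_extra_convs=True))

imgs = torch.randn(6, 3, 928, 1600)   # 6 camera views of one sample
feats = neck(backbone(imgs))          # 4 FPN levels, all with 256 channels
for f in feats:
    print(f.shape)  # strides 8/16/32/64: (6, 256, 116, 200), (6, 256, 58, 100), ...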
Head: Detr3DHead
The head inherits from the DETRHead provided by mmdet.
pts_bbox_head=dict(
    type='Detr3DHead',
    num_query=900,
    num_classes=10,
    in_channels=256,
    sync_cls_avg_factor=True,
    with_box_refine=True,
    as_two_stage=False,
    transformer=dict(),
    bbox_coder=dict(
        type='NMSFreeCoder',
        post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
        pc_range=point_cloud_range,
        max_num=300,
        voxel_size=voxel_size,
        num_classes=10),
    positional_encoding=dict(
        type='SinePositionalEncoding',
        num_feats=128,
        normalize=True,
        offset=-0.5),
    loss_cls=dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=2.0),
    loss_bbox=dict(type='L1Loss', loss_weight=0.25),
    loss_iou=dict(type='GIoULoss', loss_weight=0.0))
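As a small illustration of how these loss dicts become callable modules inside the head (assuming mmdet's build_loss; the tensor shapes below are purely illustrative):

# Minimal sketch, assuming mmdet's loss registry; shapes are illustrative only.
import torch
from mmdet.models import build_loss

loss_cls = build_loss(dict(type='FocalLoss', use_sigmoid=True,
                           gamma=2.0, alpha=0.25, loss_weight=2.0))
loss_bbox = build_loss(dict(type='L1Loss', loss_weight=0.25))

cls_scores = torch.randn(900, 10)       # 900 queries, 10 classes (sigmoid logits)
labels = torch.randint(0, 11, (900,))   # label 10 == background for sigmoid FocalLoss
bbox_preds = torch.randn(900, 10)       # 10-dim encoded boxes
bbox_targets = torch.randn(900, 10)

print(loss_cls(cls_scores, labels))
print(loss_bbox(bbox_preds, bbox_targets))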
Detr3DTransformer
This is the lowest-level part and implements the paper's main innovations:
- The queries predict 3D reference points, which are projected onto the feature maps via the sensors' transformation matrices, and the features at the projected locations are gathered by bilinear interpolation;
- The sampled features are used to refine the object queries, and the refined queries are used for object prediction (a worked projection example is sketched right after this list).
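To make the projection concrete, here is a minimal sketch of how one normalized reference point is denormalized with point_cloud_range and projected to normalized image coordinates. The lidar2img matrix and the numbers are made up for illustration; the real matrices come from the nuScenes calibration in img_metas, and the image size is the assumed padded resolution.

# Minimal sketch of the reference-point projection; the matrix and numbers are
# made up, the real values come from img_metas['lidar2img'].
import torch

pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]   # nuScenes point cloud range
img_h, img_w = 928, 1600                            # padded image size (assumed)

ref = torch.tensor([0.6, 0.4, 0.5])                 # sigmoid-normalized (x, y, z)
xyz = torch.empty(3)
xyz[0] = ref[0] * (pc_range[3] - pc_range[0]) + pc_range[0]
xyz[1] = ref[1] * (pc_range[4] - pc_range[1]) + pc_range[1]
xyz[2] = ref[2] * (pc_range[5] - pc_range[2]) + pc_range[2]

lidar2img = torch.eye(4)                              # stand-in projection matrix
pt_cam = lidar2img @ torch.cat([xyz, torch.ones(1)])  # homogeneous projection
uv = pt_cam[:2] / pt_cam[2].clamp(min=1e-5)           # perspective division
uv = uv / torch.tensor([img_w, img_h])                # normalize to [0, 1]
grid = (uv - 0.5) * 2                                 # to [-1, 1] for grid_sample
print(grid)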
transformer=dict(
    type='Detr3DTransformer',
    decoder=dict(
        type='Detr3DTransformerDecoder',
        num_layers=6,
        return_intermediate=True,
        transformerlayers=dict(
            type='DetrTransformerDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                dict(
                    type='Detr3DCrossAtten',
                    pc_range=point_cloud_range,
                    num_points=1,
                    embed_dims=256)
            ],
            feedforward_channels=512,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))))
This module is responsible for the key parts of DETR3D: reference points, feature sampling, query refinement, and cross-attention between object queries and image features.
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def forward(self,
                mlvl_feats,
                query_embed,
                reg_branches=None,
                **kwargs):
        """
        mlvl_feats (list(Tensor)): [bs, embed_dims, h, w].
        query_embed (Tensor): [num_query, c].
        reg_branches (obj:`nn.ModuleList`): regression heads, used when
            with_box_refine is enabled.
        """
        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)
        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
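For reference, the query embedding is created in the head and the reference-point projection is a small linear layer in the transformer. The sketch below shows the assumed definitions and the resulting shapes; the layer definitions mirror the public DETR3D implementation but are reproduced here only for shape illustration.

# Minimal sketch of where query_embed and reference_points come from.
import torch
import torch.nn as nn

num_query, embed_dims, bs = 900, 256, 1

query_embedding = nn.Embedding(num_query, embed_dims * 2)  # defined in Detr3DHead
reference_points_fc = nn.Linear(embed_dims, 3)             # defined in Detr3DTransformer

query_embed = query_embedding.weight                            # (900, 512)
query_pos, query = torch.split(query_embed, embed_dims, dim=1)  # (900, 256) each
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)           # (1, 900, 256)

reference_points = reference_points_fc(query_pos).sigmoid()     # (1, 900, 3) in [0, 1]
print(query_pos.shape, reference_points.shape)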
Decoder
The key to the decoder is how it realizes the object query refinement proposed in the paper, so this part is covered in detail:
Figure 2. Queries refinement
- Decoder Block
  Each decoder block runs the following flow: the reference points corresponding to the previous layer's queries are predicted and the queries are refined, self-attention is performed, and the result becomes the input to the next block: self.dropout(output) + inp_residual + pos_feat, i.e. output = original input + bilinearly interpolated feature + query positional feature.
  How are the sampled multi-scale features processed? The sampled image features go from shape (bs, c, num_query, num_cam, 1, len(num_feature_level)) to shape (bs, c, num_query): three consecutive sum(-1) calls sum the features over the camera views and over the feature scales to obtain the final image feature, which is then projected (output_proj) to the same dimensionality as the query and finally added directly to form the input of the next decoder block (a shape sketch follows the code line below).
output = output.sum(-1).sum(-1).sum(-1)
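A tiny shape sketch of this reduction (dummy tensors, illustrative dimensions only):

# Illustrative shapes only: bs=1, c=256, num_query=900, num_cam=6, num_levels=4.
import torch

output = torch.randn(1, 256, 900, 6, 1, 4)  # (bs, c, num_query, num_cam, 1, num_levels)
output = output.sum(-1).sum(-1).sum(-1)     # sum over levels, the singleton dim, cameras
print(output.shape)                         # torch.Size([1, 256, 900])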
@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        if residual is None:
            inp_residual = query  # keep the original query as the residual branch
        query = query.permute(1, 0, 2)
        bs, num_query, _ = query.size()
        # per-query attention weights over (cameras x points x feature levels)
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        # sum over feature levels, the singleton point dim, and camera views
        output = output.sum(-1).sum(-1).sum(-1)
        output = output.permute(2, 0, 1)
        output = self.output_proj(output)
        pos_feat = self.position_encoder(
            inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
- Decoder Block: feature sampling
This part shows how the image features are extracted: features are interpolated separately from each feature level and then used to refine the queries.
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    # stack the lidar->image projection matrices of all cameras
    lidar2img = []
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # denormalize the sigmoid-space reference points into real 3D coordinates
    reference_points[..., 0:1] = reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
    # homogeneous coordinates, then project into every camera
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    # points behind the camera are masked out
    mask = (reference_points_cam[..., 2:3] > eps)
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
    # normalize pixel coordinates to [-1, 1] for grid_sample
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    reference_points_cam = (reference_points_cam - 0.5) * 2
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0)
                 & (reference_points_cam[..., 1:2] > -1.0)
                 & (reference_points_cam[..., 1:2] < 1.0))
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    # bilinearly sample each feature level at the projected locations
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask
About F.grid_sample(): the projected coordinates are normalized to [-1, 1], which is exactly the coordinate convention grid_sample expects; it then reads the feature map at those sub-pixel locations with bilinear interpolation, returning zeros for out-of-range points under the default padding mode.
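A minimal, self-contained example of the sampling behavior (dummy feature map and coordinates, purely illustrative):

# Illustrative F.grid_sample usage: sample one feature map at two normalized points.
import torch
import torch.nn.functional as F

feat = torch.arange(16.0).view(1, 1, 4, 4)   # (N, C, H, W)
# grid holds (x, y) in [-1, 1]; shape (N, H_out, W_out, 2)
grid = torch.tensor([[[[-1.0, -1.0],          # top-left corner
                       [ 0.0,  0.0]]]])       # image center
out = F.grid_sample(feat, grid, mode='bilinear', align_corners=False)
print(out)  # bilinearly interpolated values at the two locations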