1. 定义一个自己的类
在csnmemde.py中,导入mmaction.models.builder 中的HEADS,使用HEADS注册器写在class上面。
还定义了一个mmdet_imported用来最后的一步register_module() #最后一步的作用不晓得。
# csnmemde.py
from mmaction.models.builder import HEADS
try:
from mmdet.models import BACKBONES as MMDET_BACKBONES # 定义backbone时用到这句
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False
@HEADS.register_module()
class ResNetCSNMem(nn.Module):
def __init__(self, chnum_in, mem_dim, feature_num,
feature_num_2, feature_num_x2, feature_num_x4,
feature_num_x6, feature_num_x8, shrink_thres=0.0025):
super(ResNetCSNMem, self).__init__()
print('ResNetCSNCov3DMem')
self.chnum_in = chnum_in # 通道数
self.feature_num = feature_num
self.feature_num_2 = feature_num_2
self.feature_num_x2 = feature_num_x2
self.feature_num_x4 = feature_num_x4
self.feature_num_x6 = feature_num_x6
self.feature_num_x8 = feature_num_x8
if mmdet_imported:
MMDET_SHARED_HEADS.register_module()(ResNetCSNMem)
2.把ResNetCSNMem类放到mmaction的包中
在 mmacyion2/mmaction/models/head 的目录下添加csnmemde.py,或者可以直接在改目录下编辑代码。
3. 在head的__init__.py中添加ResNetCSNMem类
from .x3d_head import X3DHead
from .csnmemde import ResNetCSNMem
__all__ = [
'TSNHead', 'I3DHead', 'BaseHead', 'TSMHead', 'SlowFastHead', 'SSNHead',
'TPNHead', 'AudioTSNHead', 'X3DHead', 'BBoxHeadAVA', 'AVARoIHead',
'FBOHead', 'LFBInferHead', 'TRNHead', 'TimeSformerHead', 'ACRNHead',
'STGCNHead', 'ResNetCSNMem'
]
3.加载model
定义完一个新的类,第一次用到该类的时候,要确保重新activate了对应的虚拟环境,?重新activate了对应的虚拟环境,?重新activate了对应的虚拟环境, 这样这个新类才会注册到mmaction全局。
from mmaction.models import build_head
from mmcv import Config
cfg = Config.fromfile('config/csncfg.py')
memde = build_head(cfg.model.cls_head)
判断是否将新的类注册到全局,打印其HEADS看一下,注册成功!
from mmaction.models import HEADS
HEADS
?
?对于mmdet的注册方法一样,还可以自己定义BACKBONE, NECK等。这里完成了Registry的部分,如果用这个类,还需要定义config,使用时builder一下。
?
|