Proposal and CropAndResize
1 custom plugin
- config.py 是为convert_to_uff 命令定义的。config.py文件中应该通过修改op字段将自定义层映射到TensorRT中的插件名称。插件参数的名称也应该与TensorRT插件所期望的完全匹配。如果config.py定义正确。NvUffParser将能够解析网络并使用正确的参数调用适当的插件。
- CropAndResized插件:根据产生的ROI坐标从特征图中裁剪出补丁,并将其调整为通用目标大小,例如7*7。输出张量用作“CropAndResize”插件后面的分类器的输入。
- Propsal 插件:从RPN中细化候选框,功能包括根据置信度选择顶框,执行NMS,最后选择在NMS之后具有最高置信度的顶框。
2 prerequisites
./download_model.sh 下载相关的数据和模型- patch uff converter
给UFF converter打补丁,修复UFF包中Softmax层问题,让UFF_ROOT表示python UFF包的根目录,比如/usr/lib/python2.7/dist-packages/uff ,让后使用以下命令修补程序
patch UFF_ROOT/converters/tensorflow/converter_functions.py < fix_softmax.patch
补丁文件 fix_softmax.patch是使用TensorRT 5.1 GA中的UFF软件包版本0.6.3生成的,在应用补丁之前,需要确认UFF软件版本也是0.6.3的。对于TensorRT 6.0,可以忽略它,因为它应该已经被修复了。
config.py 中添加新的插件
convert-to-uff -p config.py -O dense_class/Softmax -O dense_regress/BiasAdd -O proposal faster_rcnn.pb
3 run
- 上述步骤会在build/cmake/out文件夹中创建
sample_uff_faster_rcnn 可执行文件。
./sample_uff_faster_rcnn --datadir /data/uff_faster_rcnn -W 480 -H 272 -I 2016_1111_185016_003_00001_night_000441.ppm
INT8 mode:
./sample_uff_faster_rcnn --datadir /data/uff_faster_rcnn -i -W 480 -H 272 -I 2016_1111_185016_003_00001_night_000441.ppm --int8
4 code
config.py
import tensorflow as tf
import graphsurgeon as gs
CropAndResize = gs.create_plugin_node(
name='roi_pooling_conv_1/CropAndResize_new',
op="CropAndResize",
inputs=['activation_7/Relu', 'proposal'],
crop_height=7,
crop_width=7)
Proposal = gs.create_plugin_node(
name='proposal',
op='Proposal',
inputs=['rpn_out_class/Sigmoid', 'rpn_out_regress/BiasAdd'],
input_height=272,
input_width=480,
rpn_stride=16,
roi_min_size=1.0,
nms_iou_threshold=0.7,
pre_nms_top_n=6000,
post_nms_top_n=300,
anchor_sizes=[32.0, 64.0, 128.0],
anchor_ratios=[1.0, 0.5, 2.0])
namespace_plugin_map = {
"crop_and_resize_1/Reshape" : CropAndResize,
'crop_and_resize_1/CropAndResize' : CropAndResize,
"crop_and_resize_1/transpose" : CropAndResize,
"crop_and_resize_1/transpose_1" : CropAndResize
}
def preprocess(dynamic_graph):
dynamic_graph.append(Proposal)
dynamic_graph.remove(dynamic_graph.find_nodes_by_name('input_2'))
dynamic_graph.collapse_namespaces(namespace_plugin_map)
|