文件位置:CenterFusion-master/experiments/test.sh 和 CenterFusion-master/src/test.py 文件作用:CenterFusion 项目验证的执行过程 注意:本文中的代码都是 CenterFusion 原始代码,一些参数没有修改
一、test.sh 脚本
- 在 README.md 中训练模型的命令是:
bash experiments/test.sh - 首先执行的就是 test.sh 脚本
- 在脚本中
--参数 值 表示可选参数
export CUDA_VISIBLE_DEVICES=1
cd src
python test.py ddd \
--exp_id centerfusion \
'''
项目名称
'''
--dataset nuscenes \
'''
设置 nuscenes 数据集
'''
--val_split mini_val \
'''
验证集
'''
--run_dataset_eval \
'''
在 eval 中使用数据集特定的计算函数
'''
--num_workers 4 \
'''
4 线程
'''
--nuscenes_att \
--velocity \
--gpus 0 \
'''
gpu 索引号
'''
--pointcloud \
'''
雷达点云
'''
--radar_sweeps 3 \
'''
点云图中雷达扫瞄 3 次
'''
--max_pc_dist 60.0 \
'''
移除 max_pc_dist 以外的雷达点
'''
--pc_z_offset -0.0 \
'''
向 z 方向升起所有雷达,高度为 -0.0
'''
--load_model ../models/centerfusion_e60.pth \
'''
导入模型
'''
--flip_test \
'''
翻转数据增加
'''
二、test.py 文件
if __name__ == '__main__':
opt = opts().parse()
'''
调用 opts.py 中的 parse() 函数
在 CenterFusion/src/lib/opts.py 第 305 行
'''
if opt.not_prefetch_test:
test(opt)
else:
prefetch_test(opt)
'''
由于在 test.sh 中并没有添加参数 not_prefetch_test,所以 opt.not_prefetch_test = False
最后只执行了 else 中的语句,调用 prefetch_test() 函数
'''
def prefetch_test(opt):
if not opt.not_set_cuda_env:
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
'''
由于在 test.sh 中没有添加参数 not_set_cuda_env,所以 opt.not_set_cuda_env = False,再 not 取反
执行该 if 语句
gpus_str 是 opt 中的一个 GPU 索引号字符串,如:'0,1'
这里是为了给系统添加 cuda 索引号
'''
Dataset = dataset_factory[opt.test_dataset]
'''
设置数据集对象 nuScenes,Dataset 是一个 nuScenes 类
dataset_factory 在 CenterFusion/src/lib/dataset/dataset_factory.py 第 20 行
nuScenes 对象定义在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 中
'''
opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
'''
设置一些配置信息
update_dataset_info_and_set_heads() 函数在 opys.py 第 458 行
'''
Logger(opt)
'''
新建一个 Logger 对象,记录配置信息
Logger 对象在 CenterFusion/src/lib/logger.py
'''
split = 'val' if not opt.trainval else 'test'
'''
在 test.sh 中没有添加参数 trainval,所以 opt.trainval = False,再 not 取反,为 True
所以 split = 'val'
'''
if split == 'val':
split = opt.val_split
'''
val_split 在 opts.py 中的默认值为 'val',但在 test.sh 中添加了该参数,并赋值为 'mini_val'
所以 split 的值为 'mini_val'
'''
dataset = Dataset(opt, split)
'''
传递参数 opt(配置信息)、split(数据集名称)
'''
detector = Detector(opt)
'''
Detector 类定义在 CenterFusion/src/lib/detector.py 中
'''
if opt.load_results != '':
load_results = json.load(open(opt.load_results, 'r'))
for img_id in load_results:
for k in range(len(load_results[img_id])):
if load_results[img_id][k]['class'] - 1 in opt.ignore_loaded_cats:
load_results[img_id][k]['score'] = -1
else:
load_results = {}
'''
load_results 默认值为 ''
所以执行了 else 语句
'''
data_loader = torch.utils.data.DataLoader(
PrefetchDataset(opt, dataset, detector.pre_process),
batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
'''
torch.utils.data.DataLoader 是一个数据读取的一个接口,参数:
PrefetchDataset(opt, dataset, detector.pre_process):加载数据的数据集
batch_size (int, optional):每个 batch 加载多少个样本(默认: 1)
shuffle (bool, optional):设置为 True 时会在每个 epoch 重新打乱数据(默认: False)
num_workers (int, optional):用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
pin_memory (bool, optional):设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,
这样将内存的 Tensor 转义到 GPU 的显存就会更快一些
PrefetchDataset 类,在上面,这个类继承了 torch.utils.data.Dataset 类,表示自定义了数据读取方式
最后返回一个列表给 data_loader,其中有图片 id 以及 tensor 格式的图片数据
'''
results = {}
num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
'''
num_iters 默认值为 -1,所以 num_iters = data_loader 列表的长度
'''
bar = Bar('{}'.format(opt.exp_id), max=num_iters)
'''
定义了一个进度条,如:centerfusion |### | 3/10
进度条名称为:centerfusion
进度条最大值为:num_iters
'''
time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge', 'track']
avg_time_stats = {t: AverageMeter() for t in time_stats}
'''
AverageMeter 类定义在 CenterFusion/src/lib/utils/utils.py 中
给 time_stats 中的每一个属性赋值为 AverageMeter 类
'''
if opt.use_loaded_results:
for img_id in data_loader.dataset.images:
results[img_id] = load_results['{}'.format(img_id)]
num_iters = 0
'''
在 test.sh 中没有添加参数 use_loaded_results ,所以值为 False
没有执行该 if 语句
'''
for ind, (img_id, pre_processed_images) in enumerate(data_loader):
'''
enumerate() 函数:用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
同时列出数据和数据下标
其中 ind 表示对 data_loader 列表位置的计数
img_id 为图片数据的 id
pre_processed_images 为 tensor 格式的图片数据
'''
if ind >= num_iters:
break
'''
如果遍历完 data_loader 就结束
'''
if opt.tracking and ('is_first_frame' in pre_processed_images):
if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
pre_processed_images['meta']['pre_dets'] = \
load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
else:
print()
print('No pre_dets for', int(img_id.numpy().astype(np.int32)[0]),
'. Use empty initialization.')
pre_processed_images['meta']['pre_dets'] = []
detector.reset_tracking()
print('Start tracking video', int(pre_processed_images['video_id']))
'''
由于在 test.sh 中没有添加参数 tracking,所以 opt.tracking = False
没有执行该 if 语句
'''
if opt.public_det:
if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
pre_processed_images['meta']['cur_dets'] = \
load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
else:
print('No cur_dets for', int(img_id.numpy().astype(np.int32)[0]))
pre_processed_images['meta']['cur_dets'] = []
'''
由于在 test.sh 中没有添加参数 public_det,所以 opt.public_det = False
没有执行该 if 语句
'''
ret = detector.run(pre_processed_images)
'''
run() 函数在 CenterFusion/src/lib/detector.py 第 56 行
对 tensor 格式的图片数据进行检测,并返回检测结果
'''
results[int(img_id.numpy().astype(np.int32)[0])] = ret['results']
'''
其中 img_id.numpy().astype(np.int32) 是将 img_id 强制转换成 int32 型的数据
这里是为了记录对应图片数据的检测结果,results[图片的索引号] = ret['result']
'''
Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
'''
给 bar 添加一些显示字符串
显示 ind、num_iters、bar.elapsed_td、bar.eta_td 的值
'''
for t in avg_time_stats:
avg_time_stats[t].update(ret[t])
Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
t, tm = avg_time_stats[t])
'''
update() 函数在 CenterFusion/src/lib/utils/utils.py 第 18 行
计算 ret 中每个属性的平均值和当前值,并将其添加后 bar 的后面显示在屏幕上
'''
if opt.print_iter > 0:
if ind % opt.print_iter == 0:
print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
else:
bar.next()
'''
print_iter 默认值为 0,所以执行 else 语句
打印进度条到屏幕上
'''
bar.finish()
'''
进度条完成
'''
if opt.save_results:
print('saving results to', opt.save_dir + '/save_results_{}{}.json'.format(
opt.test_dataset, opt.dataset_version))
json.dump(_to_list(copy.deepcopy(results)),
open(opt.save_dir + '/save_results_{}{}.json'.format(
opt.test_dataset, opt.dataset_version), 'w'))
'''
在 test.sh 中没有添加 save_results 参数,所以 opt.save_results = False
没有执行该 if 语句
'''
dataset.run_eval(results, opt.save_dir, n_plots=opt.eval_n_plots,
render_curves=opt.eval_render_curves)
'''
对结果进行检测评估,评估结果保存在 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion/nuscenes_eval_det_output_mini_val 下
run_eval() 函数在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 第 272 行
results :图片数据的检测结果
save_dir :保存路径为 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion
eval_n_plots :默认值为 0
eval_render_curves :渲染和保存评价曲线,在 test.sh 中没有添加该参数,则为 False
'''
def run_eval(self, results, save_dir, n_plots=10, render_curves=False):
task = 'tracking' if self.opt.tracking else 'det'
'''
由于 test.sh 中没有添加参数 tracking,所以 opt.tracking 的值为 False
所以 task = 'det'
'''
split = self.opt.val_split
'''
split = 'mini_val'
'''
version = 'v1.0-mini' if 'mini' in split else 'v1.0-trainval'
'''
version = 'v1.0-mini'
'''
self.save_results(results, save_dir, task, split)
'''
保存结果为 json 文件
'''
render_curves = 1 if render_curves else 0
'''
render_curves = 0
'''
if task == 'det':
output_dir = '{}/nuscenes_eval_det_output_{}/'.format(save_dir, split)
'''
设置输出路径
'''
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk/nuscenes/eval/detection/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--eval_set {} '.format(split) + \
'--dataroot ../data/nuscenes/ ' + \
'--version {} '.format(version) + \
'--plot_examples {} '.format(n_plots) + \
'--render_curves {} '.format(render_curves))
'''
执行官网 evaluate.py 文件
对结果进行检测评估,并输出到 output_dir 路径下
'''
else:
output_dir = '{}/nuscenes_evaltracl__output/'.format(save_dir)
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk/nuscenes/eval/tracking/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--dataroot ../data/nuscenes/')
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk-alpha02/nuscenes/eval/tracking/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--dataroot ../data/nuscenes/')
return output_dir
|