def main(det_thresh=0.05, gallery_size=-1, ignore_cam_id=True, input_path=None,
         work_dirs_root='/home/yy1/2021/mmdetection-public/work_dirs/',
         data_root='/home/yy1/2021/data/prw/PRW-v16.04.20/',
         results_name='results_1000.pkl'):
    """Run PRW person-search evaluation on a pickled detection/feature dump.

    Query (probe) features are not stored separately: for every probe image
    the detection with the highest IoU against the probe's ground-truth box
    is looked up among that image's detections, and its feature is reused.

    Args:
        det_thresh: minimum detection score kept in the gallery; forwarded
            to ``search_performance_calc``.
        gallery_size: forwarded unchanged to ``search_performance_calc``.
        ignore_cam_id: forwarded to ``search_performance_calc``; when False,
            gallery images from the probe's own camera are excluded.
        input_path: run directory (relative to ``work_dirs_root``) that
            contains the results pickle. Required.
        work_dirs_root: root directory holding the run directories.
        data_root: PRW dataset root passed to ``load_probes``/``gt_roidbs``.
        results_name: file name of the results pickle inside the run dir.

    Returns:
        Whatever ``search_performance_calc`` returns.

    Raises:
        ValueError: if ``input_path`` is not given (previously this crashed
            with a TypeError from ``str + None``).
    """
    import os
    import pickle

    if input_path is None:
        raise ValueError('input_path must name a run directory under work_dirs_root')
    results_path = os.path.join(work_dirs_root, input_path)

    # Query annotations: image name, bbox, pid (one annotated person per image).
    probe_set = load_probes(data_root)
    # Gallery annotations: image name, bboxes, pids of every person in the image.
    gallery_set = gt_roidbs(data_root)
    # Map each gallery image name to its index in gallery_set / all_dets.
    name_id = {gallery['im_name']: i for i, gallery in enumerate(gallery_set)}

    # Detection results; each row of det[0] is [x1, y1, x2, y2, score, feat...]
    # (the original author notes a 256-d feature — TODO confirm).
    # NOTE: pickle.load can execute arbitrary code; only load dumps you produced.
    with open(os.path.join(results_path, results_name), 'rb') as fid:
        all_dets = pickle.load(fid)

    # Split each image's detections into boxes+score and L2-normalized features.
    gallery_det, gallery_feat = [], []
    for det in all_dets:
        rows = det[0]  # presumably the single (person) class output — verify
        gallery_det.append(rows[:, :5])
        feat = rows[:, 5:]
        # normalize() is only called on non-empty arrays (empty ones are kept as-is).
        if rows.shape[0] > 0:
            feat = normalize(feat, axis=1)
        gallery_feat.append(feat)

    # Probe features: pick the detection in the probe image that best overlaps
    # the query ground-truth box and reuse its feature.
    probe_feat = []
    for probe in probe_set:
        name = probe['im_name']
        gallery_idx = name_id[name]
        query_gt_box = probe['boxes'][0]
        _, iou_max, nmax = get_max_iou(gallery_det[gallery_idx], query_gt_box)
        if iou_max < 0.1:
            # The best-matching detection barely overlaps the query; its feature
            # is still used, but the miss is reported.
            print("not detected", name, iou_max)
        probe_feat.append(gallery_feat[gallery_idx][nmax])

    return search_performance_calc(gallery_set, probe_set, gallery_det, gallery_feat,
                                   probe_feat, det_thresh, gallery_size, ignore_cam_id)
def search_performance_calc(gallery_set, probe_set,
                            gallery_det, gallery_feat, probe_feat,
                            det_thresh=0.5, gallery_size=-1, ignore_cam_id=True):
    """Compute person-search mAP and top-k accuracy, print them and return them.

    For each probe, every gallery image except the probe's own image is
    searched (and, unless ``ignore_cam_id``, except images from the probe's
    camera). Detections are ranked by cosine similarity (dot product of the
    features). In an image that contains the probe identity, only the
    highest-ranked detection whose IoU with the ground-truth box reaches the
    threshold is labelled positive. AP is scaled by the recall of
    ground-truth occurrences.

    Args:
        gallery_set: per-image annotation dicts with keys ``im_name``,
            ``boxes``, ``gt_pids`` and ``cam_id``.
        probe_set: per-probe dicts with the same keys.
        gallery_det: per-image arrays whose columns 0-3 are the box and
            column 4 is the detection score.
        gallery_feat: per-image feature arrays aligned row-wise with
            ``gallery_det``.
        probe_feat: one feature per probe, aligned with ``probe_set``.
        det_thresh: detections scoring below this are discarded.
        gallery_size: unused; kept for interface compatibility.
        ignore_cam_id: when False, gallery images sharing the probe's
            ``cam_id`` are excluded.

    Returns:
        ``(mAP, accs)``, where ``accs[i]`` is the mean top-``k`` accuracy for
        ``k`` in ``[1, 5, 10]``.

    Raises:
        ValueError: if ``probe_set`` is empty.
    """
    import numpy as np

    assert len(gallery_set) == len(gallery_det)
    assert len(gallery_set) == len(gallery_feat)
    assert len(probe_set) == len(probe_feat)
    if len(probe_set) == 0:
        raise ValueError('probe_set is empty; nothing to evaluate')

    # Keep only detections with score >= det_thresh; images left with no
    # detection are absent from the dict.
    name_to_det_feat = {}
    for gt, det, feat in zip(gallery_set, gallery_det, gallery_feat):
        inds = np.where(det[:, 4].ravel() >= det_thresh)[0]
        if len(inds) > 0:
            name_to_det_feat[gt['im_name']] = (det[inds], feat[inds])

    aps = []
    accs = []
    topk = [1, 5, 10]
    for probe, probe_f in zip(probe_set, probe_feat):
        feat_p = probe_f.ravel()
        probe_imname = probe['im_name']
        probe_pid = probe['gt_pids']
        probe_cam = probe['cam_id']

        # Ground-truth boxes of the probe identity in every other image
        # (any camera), keyed by image name.
        probe_gts = {}
        for x in gallery_set:
            if x['im_name'] != probe_imname and probe_pid in x['gt_pids']:
                probe_gts[x['im_name']] = x['boxes'][x['gt_pids'] == probe_pid]

        y_true, y_score = [], []
        count_gt, count_tp = 0, 0
        # Single pass over the gallery selected for this probe.
        for item in gallery_set:
            gallery_imname = item['im_name']
            if gallery_imname == probe_imname:
                continue
            if not ignore_cam_id and item['cam_id'] == probe_cam:
                continue
            # Ground-truth occurrences count even if nothing was detected there.
            count_gt += gallery_imname in probe_gts
            if gallery_imname not in name_to_det_feat:
                continue
            det, feat_g = name_to_det_feat[gallery_imname]
            # Features must be an N x D matrix (already L2-normalized upstream).
            assert feat_g.size == np.prod(feat_g.shape[:2])
            feat_g = feat_g.reshape(feat_g.shape[:2])
            # Cosine similarity of each detection to the probe.
            sim = feat_g.dot(feat_p).ravel()
            label = np.zeros(len(sim), dtype=np.int32)
            if gallery_imname in probe_gts:
                gt = probe_gts[gallery_imname].ravel()
                w, h = gt[2] - gt[0], gt[3] - gt[1]
                # Small boxes get a looser IoU threshold (capped at 0.5).
                iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
                # Rank by descending similarity; only the first match is a TP.
                inds = np.argsort(sim)[::-1]
                sim = sim[inds]
                det = det[inds]
                for j, roi in enumerate(det[:, :4]):
                    if compute_iou(roi, gt) >= iou_thresh:
                        label[j] = 1
                        count_tp += 1
                        break
            y_true.extend(label)
            y_score.extend(sim)

        # AP over all scored detections, scaled by the recall of gt occurrences.
        y_score = np.asarray(y_score)
        y_true = np.asarray(y_true)
        assert count_tp <= count_gt
        if count_tp == 0:
            # Also covers count_gt == 0, which previously divided by zero.
            ap = 0
        else:
            recall_rate = count_tp * 1.0 / count_gt
            ap = average_precision_score(y_true, y_score) * recall_rate
        aps.append(ap)

        # Top-k: 1 if any of the k highest-scored detections is a true positive.
        inds = np.argsort(y_score)[::-1]
        y_true = y_true[inds]
        accs.append([min(1, sum(y_true[:k])) for k in topk])

    print('search ranking:')
    mAP = np.mean(aps)
    print(' mAP = {:.2%}'.format(mAP))
    accs = np.mean(accs, axis=0)
    for i, k in enumerate(topk):
        print(' top-{:2d} = {:.2%}'.format(k, accs[i]))
    return mAP, accs
|