######################################################################################
### Print per-class Precision && Recall at IoU = 0.5 for bboxes whose size is at least minsize ###
import json
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
class Dt:
    """One detection (predicted box) belonging to a single image and category."""

    def __init__(self, img_id, score, bbox, cat):
        """
        Record the fields of a single detection.

        :param img_id: id of the image the detection was produced for
        :param score: confidence score; run_match sorts detections by this value
        :param bbox: bounding box; compute_iou unpacks it as (x, y, w, h)
        :param cat: category id of the detection
        """
        self.img_id, self.score, self.bbox, self.cat = img_id, score, bbox, cat
        # Starts unmatched; run_match flips this once the detection is paired with a gt.
        self.is_match = False
class Gt:
    """One ground-truth box belonging to a single image and category."""

    def __init__(self, img_id, bbox, cat):
        """
        Record the fields of a single ground-truth annotation.

        :param img_id: id of the image the annotation belongs to
        :param bbox: bounding box; compute_iou unpacks it as (x, y, w, h)
        :param cat: category id of the annotation
        """
        self.img_id, self.bbox, self.cat = img_id, bbox, cat
        # Starts unmatched; run_match flips this once a detection claims this gt.
        self.is_match = False
def compute_iou(bbox1, bbox2):
    """
    Intersection-over-union of two boxes given as (x, y, w, h).

    :param bbox1: first box, unpacked as x, y, w, h
    :param bbox2: second box, unpacked as x, y, w, h
    :return: 0 when the boxes do not overlap (touching edges count as no
             overlap), otherwise intersection area / union area
    """
    ax, ay, aw, ah = bbox1
    bx, by, bw, bh = bbox2

    # Edges of the overlap rectangle (if any).
    left, right = max(ax, bx), min(ax + aw, bx + bw)
    top, bottom = max(ay, by), min(ay + ah, by + bh)

    # An empty or degenerate overlap means the boxes are disjoint.
    if right <= left or bottom <= top:
        return 0

    intersect = (right - left) * (bottom - top)
    # Union = sum of both areas minus the part counted twice.
    union = aw * ah + bw * bh - intersect
    return intersect / union
def prepare_gt_and_dt(gt_dict, dt_dict, c_id, min_det_size):
    """
    Select the ground truths and detections of one category and wrap them
    in Gt / Dt objects.

    :param gt_dict: loaded ground-truth data; its "annotations" and "images"
                    entries are read
    :param dt_dict: iterable of detection records (each indexed by
                    "category_id", "image_id", "bbox" and "score")
    :param c_id: category id to keep
    :param min_det_size: boxes with min(bbox[2:]) below this value are dropped
    :return: (defaultdict mapping image_id -> list of Gt,
              list of Dt,
              number of ground-truth boxes kept)
    """

    def _large_enough(box):
        # bbox[2:] holds the entries after x, y (w and h for 4-element boxes).
        return min(box[2:]) >= min_det_size

    # key: image_id, value: list of Gt for that image
    gts_by_image = defaultdict(list)
    kept_gt_count = 0
    for ann in gt_dict["annotations"]:
        if ann["category_id"] != c_id or not _large_enough(ann["bbox"]):
            continue
        kept_gt_count += 1
        gts_by_image[ann["image_id"]].append(
            Gt(img_id=ann["image_id"], bbox=ann["bbox"], cat=ann["category_id"]))

    # Detections on images that the ground truth does not list are ignored.
    known_image_ids = {image["id"] for image in gt_dict["images"]}

    detections = []
    for det in dt_dict:
        if det["category_id"] != c_id:
            continue
        if det["image_id"] not in known_image_ids:
            continue
        if not _large_enough(det["bbox"]):
            continue
        detections.append(
            Dt(img_id=det["image_id"], score=det["score"], bbox=det["bbox"], cat=det["category_id"]))

    return gts_by_image, detections, kept_gt_count
def run_match(dt_list, gt_imgid2list_dict, iou_th):
    """
    Greedily match detections to ground truths (core of the evaluation).

    Detections are visited from highest to lowest score. Each one claims the
    still-unmatched ground truth of its image with the largest IoU, provided
    that IoU reaches ``iou_th``. Claimed ground truths are flagged through
    ``is_match`` so lower-scored detections cannot reuse them.

    :param dt_list: a list of Dt objects; sorted in place by descending score
    :param gt_imgid2list_dict: key: image_id, value: a list of Gt objects
    :param iou_th: minimum IoU for a detection/ground-truth pair to match
    :return: None; results are written to the ``is_match`` flags
    """
    # Descending score order is essential: higher-scored detections pick first.
    dt_list.sort(key=lambda x: x.score, reverse=True)

    for detection in tqdm(dt_list):
        if detection.img_id not in gt_imgid2list_dict:
            continue

        best_gt = None
        best_iou = 0
        for candidate in gt_imgid2list_dict[detection.img_id]:
            if candidate.is_match:
                continue
            overlap = compute_iou(detection.bbox, candidate.bbox)
            if overlap >= iou_th and overlap > best_iou:
                best_gt, best_iou = candidate, overlap

        if best_gt is not None:
            detection.is_match = True
            best_gt.is_match = True
def get_recalls_and_precisions(dt_list, gt_dict, gt_obj_num, generate_badcase=False, dt_badcase=None, gt_badcase=None):
    """
    Build cumulative recall / precision curves (in percent) over ``dt_list``.

    recall: tp / gt_all; precision: tp / dt_all, where the i-th entry counts
    the detections ``dt_list[0..i]``. The curves are only meaningful when
    ``dt_list`` is in the order left by run_match (descending score).

    :param dt_list: a list of Dt obj whose ``is_match`` flags are already set
    :param gt_dict: key->imgId, value->a list of Gt obj
    :param gt_obj_num: gt obj num in total; when it is 0 every recall entry
                       is reported as 0.0 instead of dividing by zero
    :param generate_badcase: when True, also collect unmatched detections
                             and unmatched ground truths
    :param dt_badcase: mapping imgId -> list receiving deep copies of the
                       unmatched detections; indexed with possibly-new keys,
                       so e.g. a defaultdict(list). Required iff
                       generate_badcase is True
    :param gt_badcase: mapping imgId -> list receiving the unmatched Gt
                       objects (same requirements as dt_badcase)
    :return: (recalls, precisions, scores), three lists aligned with dt_list
    """
    no_badcase_args = dt_badcase is None and gt_badcase is None
    both_badcase_args = dt_badcase is not None and gt_badcase is not None
    assert (generate_badcase == False and no_badcase_args) or (
        generate_badcase == True and both_badcase_args), 'wrong~'

    scores = [dt.score for dt in dt_list]

    # tp_list[i] = number of true positives among dt_list[0..i].
    tp_list = []
    cur_tp = 0
    for dt in dt_list:
        if dt.is_match:
            cur_tp += 1
        elif generate_badcase:
            # Unmatched detection -> false positive.
            dt_badcase[dt.img_id].append(deepcopy(dt))
        tp_list.append(cur_tp)

    if generate_badcase:
        for img_id, gts in gt_dict.items():
            # Unmatched ground truth -> missed object (false negative).
            gt_badcase[img_id].extend([gt for gt in gts if not gt.is_match])

    if gt_obj_num:
        recalls = [(tp / gt_obj_num) * 100 for tp in tp_list]
    else:
        # No ground truth at all: recall is undefined, report 0 rather than crash.
        recalls = [0.0] * len(tp_list)
    precisions = [tp / (index + 1) * 100 for index, tp in enumerate(tp_list)]
    return recalls, precisions, scores
def find_index_by_thd_nearest(scores,thd):
    """
    Return the index of the score closest to ``thd``.

    :param scores: non-empty sequence of numeric scores
    :param thd: threshold to approach
    :return: index (plain int) of the first element with the smallest
             absolute distance to ``thd``
    :raises ValueError: if ``scores`` is empty
    """
    distances = np.abs(np.asarray(scores) - thd)
    if distances.size == 0:
        raise ValueError("scores is empty; there is no nearest index to return")
    # argmin returns the first minimum, matching list.index(min(...)).
    return int(np.argmin(distances))
if __name__ == '__main__':
    # Evaluation settings: IoU threshold for a match and minimum box side kept.
    iou = 0.5
    min_size = 0
    gtpath = "D:/DataSet/dwfan/lc/Detection_lc_parking_6class_platform/label/gt.coco"
    resPath = "D:/DataSet/dwfan/lc/Detection_lc_parking_6class_platform/result.json"

    # Context managers guarantee both files are closed, even if parsing fails.
    with open(gtpath, "r", encoding="utf-8") as f_gt:
        data_gt = json.load(f_gt)
    with open(resPath, "r", encoding="utf-8") as f_dt:
        data_dt = json.load(f_dt)

    # One confidence threshold per category, looked up with (category id - 1).
    # NOTE(review): this assumes category ids are 1..6 -- confirm against the gt file.
    confidence_thds = [0.55, 0.5, 0.5, 0.5, 0.5, 0.5]

    for category in data_gt["categories"]:
        cat_id = category["id"]
        # gt_dict_list maps image_id -> list of Gt; dt_list is a list of Dt.
        gt_dict_list, dt_list, gt_obj_num = prepare_gt_and_dt(data_gt, data_dt, cat_id, min_size)
        if gt_obj_num == 0 or not dt_list:
            # Without ground truth or detections there is no operating point to
            # report; previously this case aborted the whole run with an exception.
            print(f"category {cat_id}: skipped (no ground truth or no detections after filtering)")
            continue
        run_match(dt_list, gt_dict_list, iou)
        recalls, precisions, scores = get_recalls_and_precisions(dt_list, gt_dict_list, gt_obj_num)
        # Report the point of the curve whose score is nearest the category threshold.
        index = find_index_by_thd_nearest(scores, confidence_thds[cat_id - 1])
        score, recall, precision = scores[index], recalls[index], precisions[index]
        print(score, recall, precision)