IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 目标检测中评估每类目标在IOU为0.5的Precision&&Recall -> 正文阅读

[人工智能]目标检测中评估每类目标在IOU为0.5的Precision&&Recall

######################################################################################
###################打印IOU为0.5时的每类bbox大于minsize的Precision&&Recall###############
import json
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy

class Dt:
    def __init__(self, img_id, score, bbox, cat):
        """
        single dt
        :param img_id:
        :param score:   a list of score
        :param bbox:  bbox
        """
        self.img_id = img_id
        self.score = score
        self.bbox = bbox
        self.cat = cat
        self.is_match = False


class Gt:
    def __init__(self, img_id, bbox, cat):
        """

        :param img_id:
        :param bbox:
        """
        self.img_id = img_id
        self.bbox = bbox
        self.cat = cat
        self.is_match = False


def compute_iou(bbox1, bbox2):
    x1, y1, w1, h1 = bbox1
    x2, y2, w2, h2 = bbox2
    s_sum = w1 * h1 + w2 * h2

    left = max(x1, x2)
    right = min(x1 + w1, x2 + w2)
    top = max(y1, y2)
    bottom = min(y1 + h1, y2 + h2)

    if left >= right or top >= bottom:
        return 0
    intersect = (right - left) * (bottom - top)

    return intersect / (s_sum - intersect)


def prepare_gt_and_dt(gt_dict, dt_dict, c_id, min_det_size):
    """
    :param gt_path: gt path
    :param dt_path: dt path
    :param c_id: category_id
    :return: gt dict  and  dt list
    """

    gt_objs = gt_dict["annotations"]
    gt_objs = [gt_obj for gt_obj in gt_objs if
               gt_obj["category_id"] == c_id and min(gt_obj["bbox"][2:]) >= min_det_size]

    gt_obj_num = len(gt_objs)###实际gt中的bbox

    gt_dict_obj = defaultdict(list) # key: image_id value: a list of Gt #构建一个默认values为list的字典
    for gt_obj in gt_objs:
        gt_dict_obj[gt_obj["image_id"]].append(
            Gt(img_id=gt_obj["image_id"], bbox=gt_obj["bbox"], cat=gt_obj["category_id"]))

    image_id_set = set([im["id"] for im in gt_dict["images"]])


    dt_objs = [dt_obj for dt_obj in dt_dict
               if dt_obj["category_id"] == c_id and dt_obj["image_id"] in image_id_set and min(
            dt_obj["bbox"][2:]) >= min_det_size]

    dt_list_obj = list()

    for dt_obj in dt_objs:
        dt_list_obj.append(
            Dt(img_id=dt_obj["image_id"], score=dt_obj["score"], bbox=dt_obj["bbox"], cat=dt_obj["category_id"]))

    return gt_dict_obj, dt_list_obj, gt_obj_num


def run_match(dt_list, gt_imgid2list_dict, iou_th):####核心代码,通过设置Gt类的self.ismatch来实现在IOU都达标的情况下选择分数最高的bbx作为TP
    """
    :param dt_list:  a list of Dt class
    :param gt_dict:  key: image_id  value: a list of Gt class
    :return:
    """
    dt_list.sort(key=lambda x: x.score, reverse=True)###很重要
    # print("matching dt to gts...")
    for single_dt in tqdm(dt_list):
        img_id = single_dt.img_id
        if img_id in gt_imgid2list_dict:
            max_iou = 0
            max_index = -1
            for index, gt in enumerate(gt_imgid2list_dict[img_id]):
                if not gt.is_match:
                    cur_iou = compute_iou(single_dt.bbox, gt.bbox)
                    if cur_iou >= iou_th and max_iou < cur_iou:
                        max_index = index
                        max_iou = cur_iou
            if max_index >= 0:
                single_dt.is_match = True
                gt_imgid2list_dict[img_id][max_index].is_match = True


def get_recalls_and_precisions(dt_list, gt_dict, gt_obj_num, generate_badcase=False, dt_badcase=None, gt_badcase=None):
    """
    recall: tp / gt_all;  precision:  tp/ dt_all
    :param dt_list: a list of Dt obj
    :param gt_dict: key->imgId, value->a list of Gt obj
    :param gt_obj_num: gt obj num in total
    :return:
    """

    assert (generate_badcase == False and dt_badcase is None and gt_badcase is None) or (
                generate_badcase == True and dt_badcase is not None and gt_badcase is not None), 'wrong~'

    is_match = [dt.is_match for dt in dt_list]
    scores = [dt.score for dt in dt_list]
    tp_list = []
    cur_tp = 0
    for index, match in enumerate(is_match):
        if match:
            cur_tp += 1
        else:
            if generate_badcase:
                dt_badcase[dt_list[index].img_id].append(deepcopy(dt_list[index]))
        tp_list.append(cur_tp)
    if generate_badcase:
        for img_id in gt_dict:
            # gt_badcase[img_id].extend([gt.bbox for gt in gt_dict[img_id] if not gt.is_match])
            gt_badcase[img_id].extend([gt for gt in gt_dict[img_id] if not gt.is_math])

    recalls = [(tp / gt_obj_num) * 100 for tp in tp_list]
    precisions = [tp / (index + 1) * 100 for index, tp in enumerate(tp_list)]
    # precisions = ['{}, {}'tp, (index + 1) for index, tp in enumerate(tp_list)]
    return recalls, precisions, scores


def find_index_by_thd_nearest(scores,thd):
    ls_diff=abs(np.array(scores)-thd).tolist()
    ind=ls_diff.index(min(ls_diff))
    return ind


if __name__ == '__main__':
    iou=0.5
    min_size=0
    gtpath="D:/DataSet/dwfan/lc/Detection_lc_parking_6class_platform/label/gt.coco"
    resPath="D:/DataSet/dwfan/lc/Detection_lc_parking_6class_platform/result.json"
    f_gt = open(gtpath, "r")
    f_dt=open(resPath,"r")
    data_gt=json.load(f_gt)
    data_dt=json.load(f_dt)
    confidence_thds=[0.55,0.5,0.5,0.5,0.5,0.5]
    for i in data_gt["categories"]:
        gt_dict_list, dt_list, gt_obj_num = prepare_gt_and_dt(data_gt, data_dt, i["id"], min_size) ##gt_dict_list是一个values为列表的字典,gt_list是一个列表
        run_match(dt_list, gt_dict_list, iou)
        recalls, precisions, scores = get_recalls_and_precisions(dt_list, gt_dict_list, gt_obj_num)
        index=find_index_by_thd_nearest(scores, confidence_thds[i["id"] - 1])
        score,recall, precision = scores[index],recalls[index], precisions[index]
        print(score,recall,precision)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-22 18:36:57  更:2022-04-22 18:39:01 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 10:20:55-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码