IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> AlignPS test_result_prw.py注释 -> 正文阅读

[人工智能]AlignPS test_result_prw.py注释

def main(det_thresh=0.05, gallery_size=-1, ignore_cam_id=True, input_path=None):
    #results_path = '/raid/ljp/code/chao_mmdetection/jobs/dcn_base_focal/'

    # change here
    results_path = '/home/yy1/2021/mmdetection-public/work_dirs/' + input_path
    data_root='/home/yy1/2021/data/prw/PRW-v16.04.20/'
    # 获取query标签信息,图片名字 bbox pid 一张图片只有一个人
    probe_set = load_probes(data_root)
    # 获取gallery标签信息,图片名字 bbox pid 图片上的所有人
    gallery_set = gt_roidbs(data_root)

    name_id = dict()
    # 给图片名字弄上index
    for i, gallery in enumerate(gallery_set):
        name = gallery['im_name']
        name_id[name] = i
    # print(name_id)

    # 检测结果 其中包括 bbox score feat(256)
    with open(os.path.join(results_path, 'results_1000.pkl'), 'rb') as fid:
        all_dets = pickle.load(fid)

    # 从all取出 bbox score 给gallery_det ;取出 feat 给 gallery feat
    gallery_det, gallery_feat = [], []
    for det in all_dets:
        gallery_det.append(det[0][:, :5])
        if det[0].shape[0] > 0:
            feat = normalize(det[0][:, 5:], axis=1)
        else:
            feat = det[0][:, 5:]
        # feat = normalize(det[0][:, 5:], axis=1)
        gallery_feat.append(feat)
    # 得到 query feat 即从gallery feat 中取
    # 同样名字的图片,在一张图中,query 只有一个人,但是gallery 有多个人
    # 通过querybox与 detbox 的IOU判断query的那个人在gallery 中的哪个位置,然后取那个人的feat
    probe_feat = []
    for probe in probe_set:
        name = probe['im_name']
        query_gt_box = probe['boxes'][0]
        id = name_id[name]
        det = gallery_det[id]
        feat = gallery_feat[id]

        iou, iou_max, nmax = get_max_iou(det, query_gt_box)
        if iou_max < 0.1:
            print("not detected", name, iou_max)
        feat = feat[nmax]
        probe_feat.append(feat)
    
    # gallery_det, gallery_feat = [], []
    # for det in all_dets:
        # det[0] = det[0][det[0][:, 4]>thresh]
        # gallery_det.append(det[0][:, :5])
        # if det[0].shape[0] > 0:
        #     feat = normalize(det[0][:, 5:], axis=1)
        # else:
        #     feat = det[0][:, 5:]
        # feat = normalize(det[0][:, 5:], axis=1)
        # gallery_feat.append(feat)
    
    search_performance_calc(gallery_set, probe_set, gallery_det, gallery_feat, probe_feat, det_thresh, gallery_size, ignore_cam_id)

# @jit(forceobj=True)
def search_performance_calc(gallery_set, probe_set,
                                gallery_det, gallery_feat, probe_feat,
                                det_thresh=0.5, gallery_size=-1, ignore_cam_id=True):

    assert len(gallery_set) == len(gallery_det)
    assert len(gallery_set) == len(gallery_feat)
    assert len(probe_set) == len(probe_feat)

    gt_roidb = gallery_set
    name_to_det_feat = {}
    # 去掉得分小的,不是人的det 得到 name_to_det_feat 字典
    for gt, det, feat in zip(gt_roidb, gallery_det, gallery_feat):
        name = gt['im_name']
        pids = gt['gt_pids']
        cam_id = gt['cam_id']
        scores = det[:, 4].ravel()
        inds = np.where(scores >= det_thresh)[0]
        if len(inds) > 0:
            name_to_det_feat[name] = (det[inds], feat[inds], pids, cam_id)

    aps = []
    accs = []
    topk = [1, 5, 10]
    # ret = {'image_root': gallery_set.data_path, 'results': []}
    # 遍历每一个query图像
    for i in range(len(probe_set)):
        y_true, y_score = [], []
        imgs, rois = [], []
        count_gt, count_tp = 0, 0

        feat_p = probe_feat[i].ravel()

        probe_imname = probe_set[i]['im_name']
        probe_roi = probe_set[i]['boxes']
        probe_pid = probe_set[i]['gt_pids']
        probe_cam = probe_set[i]['cam_id']

        # Find all occurence of this probe
        # 得到gallery中包含query人物的图像,但是不能和query的图像是一张重复的 例如gallery 中有16张有该query人物
        gallery_imgs = []
        for x in gt_roidb:
            if probe_pid in x['gt_pids'] and x['im_name'] != probe_imname:
                gallery_imgs.append(x)
        # 上述16 张图,每个图中该query的位置,共有16个位置 bbox
        probe_gts = {}
        for item in gallery_imgs:
            probe_gts[item['im_name']] = \
                item['boxes'][item['gt_pids'] == probe_pid]

        # Construct gallery set for this probe
        # 去掉和query同名图片的所有gallery 图像
        if ignore_cam_id:
            gallery_imgs = []
            for x in gt_roidb:
                if x['im_name'] != probe_imname:
                    gallery_imgs.append(x)
        else:
            gallery_imgs = []
            for x in gt_roidb:
                if x['im_name'] != probe_imname and x['cam_id'] != probe_cam:
                    gallery_imgs.append(x)

        # # 1. Go through all gallery samples
        # for item in testset.targets_db:
        # Gothrough the selected gallery
        # 遍历所有gallery 图像
        for item in gallery_imgs:
            gallery_imname = item['im_name']
            # some contain the probe (gt not empty), some not
            count_gt += (gallery_imname in probe_gts)
            # compute distance between probe and gallery dets
            if gallery_imname not in name_to_det_feat:
                continue
            det, feat_g, _, _ = name_to_det_feat[gallery_imname]
            # get L2-normalized feature matrix NxD
            assert feat_g.size == np.prod(feat_g.shape[:2])
            # 维度变为256*N 的 展开的 N为一张图片中所有人 例如 4
            feat_g = feat_g.reshape(feat_g.shape[:2])
            # compute cosine similarities
            # 得到四个得分
            sim = feat_g.dot(feat_p).ravel()
            # assign label for each det
            label = np.zeros(len(sim), dtype=np.int32)
            # 如果该文件名在query的16张图里面
            if gallery_imname in probe_gts:
                # 得到bbox
                gt = probe_gts[gallery_imname].ravel()
                w, h = gt[2] - gt[0], gt[3] - gt[1]
                # 如果图片较小IOU阈值设大一点
                iou_thresh = min(0.5, (w * h * 1.0) /
                                    ((w + 10) * (h + 10)))
                #iou_thresh = min(0.3, (w * h * 1.0) /
                #                    ((w + 10) * (h + 10)))
                # 倒叙
                inds = np.argsort(sim)[::-1]
                # sim 由大到小
                sim = sim[inds]
                det = det[inds]
                # only set the first matched det as true positive
                for j, roi in enumerate(det[:, :4]):
                    if compute_iou(roi, gt) >= iou_thresh:
                        label[j] = 1
                        count_tp += 1
                        break
            y_true.extend(list(label))
            y_score.extend(list(sim))
            imgs.extend([gallery_imname] * len(sim))
            rois.extend(list(det))

        # 2. Compute AP for this probe (need to scale by recall rate)
        # 长度为 gally中的所有行人数量,一万多人
        y_score = np.asarray(y_score)
        y_true = np.asarray(y_true)
        # count_tp 16张图中,det 与 gt box 的IOU 大于阈值
        assert count_tp <= count_gt
        recall_rate = count_tp * 1.0 / count_gt
        ap = 0 if count_tp == 0 else \
            average_precision_score(y_true, y_score) * recall_rate
        aps.append(ap)
        # 由大到小排序
        inds = np.argsort(y_score)[::-1]
        y_score = y_score[inds]
        y_true = y_true[inds]
        # 前k个,有该行人 则为1 ,若前一个没有这个行人,则top1为0,若前三个有,则top3为 1,若前三个中一个也不是则top3为0
        accs.append([min(1, sum(y_true[:k])) for k in topk])
        # # 4. Save result for JSON dump
        # new_entry = {'probe_img': str(probe_imname),
        #                 'probe_roi': map(float, list(probe_roi.squeeze())),
        #                 'probe_gt': probe_gts,
        #                 'gallery': []}
        # # only save top-10 predictions
        # for k in range(10):
        #     new_entry['gallery'].append({
        #         'img': str(imgs[inds[k]]),
        #         'roi': map(float, list(rois[inds[k]])),
        #         'score': float(y_score[k]),
        #         'correct': int(y_true[k]),
        #     })
        # ret['results'].append(new_entry)

    print('search ranking:')
    mAP = np.mean(aps)
    print('  mAP = {:.2%}'.format(mAP))
    accs = np.mean(accs, axis=0)
    for i, k in enumerate(topk):
        print('  top-{:2d} = {:.2%}'.format(k, accs[i]))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-07 22:41:41  更:2022-04-07 22:42:35 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:29:26-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码