IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 关于Mask R-CNN 画PR曲线 -> 正文阅读

[人工智能]关于Mask R-CNN 画PR曲线

最近太多人问我如何绘制PR曲线了,我又很少及时看到你们的消息,在这里跟大家道个歉,我直接把代码贴出来,你们看着改参数就好。

################ 导入相关包 #####################
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
ROOT_DIR = os.path.abspath("../../")
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from samples.hpv import hpv    # 这里是我自己写的脚本  继承的参数 一般是nucleus继承过来的
##############  配置参数  ####
LOGS_DIR = os.path.join(ROOT_DIR, "logs")
DATASET_DIR = os.path.join(ROOT_DIR, "datasets/hpv")   #  数据集
config = hpv.NucleusInferenceConfig()
DEVICE = "/cpu:0"
TEST_MODE = "inference"
def get_ax(rows=1, cols=1, size=16):
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    fig.tight_layout()
    return ax
def text_save(filename, data):#filename为写入CSV文件的路径,data为要写入数据列表.
    file = open(filename, 'a')
    for i in range(len(data)):
        s = str(data[i]).replace('[','').replace(']','')#去除[],这两行按数据不同,可以选择
        s = s.replace("'",'').replace(',','') +'\n'   #去除单引号,逗号,每行末尾追加换行符
        file.write(s)
    file.close()
    print("保存txt文件成功")

#####  加载测试集数据  #####
dataset = hpv.NucleusDataset()
dataset.load_nucleus(DATASET_DIR, "stage1_test")
dataset.prepare()
print("Images: {}\nClasses: {}".format(len(dataset.image_ids), dataset.class_names))

#####  导入模型  ####
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=LOGS_DIR, config=config)
weights_path = "/Mask_RCNN/logs/model1-120211011T1528/mask_rcnn_model1-1_0300.h5"
model.load_weights(weights_path, by_name=True)
image_ids = dataset.image_ids

APs = []
count1 = 0
for image_id in image_ids:
    info = dataset.image_info[image_id]
    print("image_id: ", image_id)
    # ####重要步骤:获得测试图片的信息
    image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(dataset, config, image_id, use_mini_mask=False)
    # ###保存实际结果
    if count1 == 0:
        save_box, save_class, save_mask = gt_bbox, gt_class_id, gt_mask
    else:
        save_box = np.concatenate((save_box, gt_bbox), axis=0)
        save_class = np.concatenate((save_class, gt_class_id), axis=0)
        save_mask = np.concatenate((save_mask, gt_mask), axis=2)
    molded_images = np.expand_dims(modellib.mold_image(image, config), 0)
    # # 显示检测结果
    # results = model.detect_molded(np.expand_dims(image, 0), np.expand_dims(image_meta, 0), verbose=1)
    results = model.detect_molded(np.expand_dims(image, 0), np.expand_dims(image_meta, 0), verbose=1)
    r = results[0]
    # 保存预测结果
    if count1 == 0:
        save_roi, save_id, save_score, save_m = r["rois"], r["class_ids"], r["scores"], r['masks']
    else:
        save_roi = np.concatenate((save_roi, r["rois"]), axis=0)
        save_id = np.concatenate((save_id, r["class_ids"]), axis=0)
        save_score = np.concatenate((save_score, r["scores"]), axis=0)
        save_m = np.concatenate((save_m, r['masks']), axis=2)
    count1 += 1
# AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'])
# APs.append(AP)


# # 在阈值0.5到0.95之间每隔0.1显示AP值
# utils.compute_ap_range(gt_bbox_all, gt_class_id_all, gt_mask_all, pre_rois_all, pre_class_ids_all, pre_scores_all, pre_masks_all, verbose=1)
## 在图片中显示真实与预测之间的差异
# visualize.display_differences(image, gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'],
#                               dataset.class_names, ax=get_ax(), show_box=False, show_mask=False, iou_threshold=0.5, score_threshold=0.5)
# plt.show()

# ######绘制PR曲线######

AP, precisions, recalls, overlaps = \
        utils.compute_ap(save_box, save_class, save_mask,
                         save_roi, save_id, save_score, save_m)
print("AP: ", AP)
# print("mAP: ", np.mean(APs))

plt.plot(recalls, precisions, 'b', label='PR')
plt.title('Precision-Recall Curve. AP@50 = {:.3f}'.format(AP))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.show()
text_save('preci-model1.txt', precisions)
text_save('recall-model1.txt', recalls)

这个方法十分吃虚拟内存,就是这个脚本在哪个盘运行,就要设置大量虚拟内存,一般20图片挺快的,后来尝试修改代码,但是识别的结果有点不一样,后来也就没有继续研究了。
第一次写这个,不足之处请见谅。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-08 13:48:54  更:2021-12-08 13:50:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 0:46:09-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码