最近很多人问我如何绘制PR曲线，而我又很少能及时看到大家的消息，在这里向大家道个歉。我直接把代码贴出来，大家根据自己的情况修改参数即可。
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Absolute path two levels above the current working directory; appended to
# the import path so the `mrcnn` / `samples` imports that follow can be found.
ROOT_DIR = os.path.abspath(os.path.join("..", ".."))
sys.path.append(ROOT_DIR)
from mrcnn import utils
import mrcnn.model as modellib
from samples.hpv import hpv
# Locations under the repository root: checkpoints/logs and the test dataset.
LOGS_DIR = os.path.join(ROOT_DIR, "logs")
DATASET_DIR = os.path.join(ROOT_DIR, "datasets/hpv")

# TensorFlow device used to build the model, and the model mode name.
# NOTE(review): TEST_MODE is not referenced in the visible script (the model
# is built with a literal mode string); kept for compatibility.
DEVICE, TEST_MODE = "/cpu:0", "inference"

# Inference configuration provided by the project's `hpv` sample module.
config = hpv.NucleusInferenceConfig()
def get_ax(rows=1, cols=1, size=16):
    """Create a ``rows`` x ``cols`` grid of matplotlib axes and return the axes.

    Each grid cell is ``size`` x ``size`` inches; the figure layout is
    tightened before returning.

    Args:
        rows: Number of subplot rows.
        cols: Number of subplot columns.
        size: Edge length of one cell, in inches.

    Returns:
        Whatever ``plt.subplots`` returns as its axes object (a single Axes
        for a 1x1 grid, otherwise an array of Axes).
    """
    width, height = size * cols, size * rows
    figure, axes = plt.subplots(rows, cols, figsize=(width, height))
    figure.tight_layout()
    return axes
def text_save(filename, data, mode='a'):
    """Write each element of ``data`` to a text file, one element per line.

    Every element is converted with ``str()`` and then stripped of the
    characters ``[``, ``]``, ``'`` and ``,`` so that, for example, ``[1, 2]``
    is written as ``1 2``. A confirmation message is printed when done.

    Args:
        filename: Path of the output file.
        data: Iterable of items to write (e.g. a NumPy array or list).
        mode: File mode. Defaults to ``'a'`` (append), so repeated runs add
            lines to an existing file; pass ``'w'`` to overwrite instead.
    """
    # One-pass deletion of brackets, single quotes and commas.
    strip_table = str.maketrans('', '', "[]',")
    # `with` guarantees the file is closed even if a write fails.
    with open(filename, mode, encoding='utf-8') as out:
        out.writelines(str(item).translate(strip_table) + '\n' for item in data)
    print("保存txt文件成功")
# IoU threshold for a prediction/ground-truth match (the plot reports AP@50).
IOU_THRESHOLD = 0.5


def _pooled_pr_curve(score_chunks, tp_chunks, n_gt):
    """Return ``(AP, precisions, recalls)`` over predictions from all images.

    Args:
        score_chunks: List of 1-D arrays, one per image, of prediction scores.
        tp_chunks: List of 1-D bool arrays aligned with ``score_chunks``;
            True where that prediction was matched to a ground-truth instance
            of its *own* image.
        n_gt: Total number of ground-truth instances across all images.

    Raises:
        ValueError: if ``n_gt`` is 0 (recall would be undefined).

    The arithmetic is intended to mirror ``utils.compute_ap`` (matterport
    Mask R-CNN; NOTE(review): confirm against this fork's utils). It computes
    cumulative precision/recall in descending score order, padded with
    ``(recall 0, precision 0)`` and ``(recall 1, precision 0)``, makes
    precision non-increasing, then sums area under the step curve.
    """
    if n_gt == 0:
        raise ValueError("test set has no ground-truth instances; recall is undefined")
    scores = np.concatenate(score_chunks)
    is_tp = np.concatenate(tp_chunks).astype(bool)
    # Rank every prediction in the test set by confidence, highest first.
    order = np.argsort(-scores, kind="stable")
    hits = np.cumsum(is_tp[order])
    precisions = hits / np.arange(1, len(hits) + 1)
    recalls = hits.astype(np.float32) / n_gt
    precisions = np.concatenate([[0], precisions, [0]])
    recalls = np.concatenate([[0], recalls, [1]])
    # Precision envelope: each point takes the best precision at any later point.
    precisions = np.maximum.accumulate(precisions[::-1])[::-1]
    # Integrate only where recall actually changes.
    steps = np.where(recalls[:-1] != recalls[1:])[0] + 1
    ap = np.sum((recalls[steps] - recalls[steps - 1]) * precisions[steps])
    return ap, precisions, recalls


dataset = hpv.NucleusDataset()
dataset.load_nucleus(DATASET_DIR, "stage1_test")
dataset.prepare()
print("Images: {}\nClasses: {}".format(len(dataset.image_ids), dataset.class_names))

with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=LOGS_DIR, config=config)
# NOTE(review): machine-specific absolute checkpoint path; adjust locally.
weights_path = "/Mask_RCNN/logs/model1-120211011T1528/mask_rcnn_model1-1_0300.h5"
model.load_weights(weights_path, by_name=True)

image_ids = dataset.image_ids
all_scores = []  # per image: prediction scores, highest first
all_tp = []      # per image: True where the prediction matched a GT instance
n_gt = 0         # ground-truth instances across the whole test set
for image_id in image_ids:
    print("image_id: ", image_id)
    image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
        dataset, config, image_id, use_mini_mask=False)
    # detect_molded() expects already-molded input (in matterport's model.py it
    # does not call mold_image itself, while detect() and training do), so pass
    # the mean-subtracted image rather than the raw one.
    molded_images = np.expand_dims(modellib.mold_image(image, config), 0)
    r = model.detect_molded(molded_images, np.expand_dims(image_meta, 0), verbose=1)[0]
    # Match predictions against this image's ground truth only. Pooling masks of
    # different images into one compute_ap call lets predictions match objects
    # from other images and builds a huge all-vs-all mask overlap matrix.
    gt_match, pred_match, _ = utils.compute_matches(
        gt_bbox, gt_class_id, gt_mask,
        r["rois"], r["class_ids"], r["scores"], r["masks"],
        iou_threshold=IOU_THRESHOLD)
    # compute_matches reports pred_match in descending-score order over the
    # first len(pred_match) scores (assumed from matterport utils — confirm).
    all_scores.append(np.sort(r["scores"][:len(pred_match)])[::-1])
    all_tp.append(pred_match > -1)
    n_gt += len(gt_match)

AP, precisions, recalls = _pooled_pr_curve(all_scores, all_tp, n_gt)
print("AP: ", AP)

plt.plot(recalls, precisions, 'b', label='PR')
plt.title('Precision-Recall Curve. AP@{:.0f} = {:.3f}'.format(IOU_THRESHOLD * 100, AP))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.show()
text_save('preci-model1.txt', precisions)
text_save('recall-model1.txt', recalls)
这个方法非常占用虚拟内存：脚本在哪个盘运行，就需要在那个盘上设置大量虚拟内存。一般20张图片的话处理得挺快。后来我尝试修改代码，但识别结果有些不一样，也就没有继续研究了。第一次写这个，不足之处请见谅。
|