1. 先根据 val.txt 把原始的 mask 标注图片抽取出来
import sys
sys.path.append("..")
sys.path.insert(0, '.')
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import os
torch.set_grad_enabled(False)
np.random.seed(123)
# Dataset root; val.txt and the relative paths listed inside it live under this folder.
data_root = "F:\\1207garbage_classification\\0505_garbage_optimize\\"
# Output folder for the ground-truth masks of the validation split.
val_path = "F:\\1207garbage_classification\\0505_garbage_optimize\\val_mask_ori\\"
os.makedirs(val_path, exist_ok=True)
# Each non-blank line of val.txt is split on "," and its second field is used as the
# mask path (presumably "<image>,<mask>" - confirm against whatever generated val.txt).
# The `with` block closes the list file (the original left the handle open).
with open(data_root + "val.txt", "r") as list_file:
    data = [line for line in list_file if line.strip()]
i = 0
for img in data:
    fields = img.split(",")
    if len(fields) < 2:
        # A line without a second field used to crash with IndexError.
        print("skip malformed line:", img.strip())
        continue
    print(fields[1])
    filename = fields[1].strip()
    # Re-save the mask under its base name; `with` releases the source file handle.
    with Image.open(os.path.join(data_root, filename)) as a:
        a.save(os.path.join(val_path, os.path.basename(filename)))
    i += 1
print("一共验证集val多少图片:", i)
2. 把验证集里的原图取出来进行测试，用 demo 推理出掩码图并进行可视化
import sys
sys.path.append("..")
sys.path.insert(0, '.')
import argparse
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import cv2
import time
import os
import lib.transform_cv2 as T
from lib.models import model_factory
from configs import set_cfg_from_file
from tqdm import tqdm
torch.set_grad_enabled(False)
np.random.seed(123)
sys.path.append('G:\\addwater0906\\')
# Dataset root; val.txt and the relative image paths it lists live here.
data_root = 'G:\\addwater0906\\'
# Folder that receives a flat copy of every validation image (also the default --img_path below).
val_path = "G:\\addwater0906\\val\\"
# cv2.imwrite does not raise when the folder is missing - it just returns False - so create it.
os.makedirs(val_path, exist_ok=True)
# Keep only non-blank lines; the `with` block closes the file handle.
with open(data_root + "val.txt", "r") as list_file:
    data = [line for line in list_file if line.strip()]
i = 0
for img in data:
    # First comma-separated field is the image path relative to data_root.
    filename = img.split(",")
    print(filename[0])
    rel_path = filename[0].strip()
    src = os.path.join(data_root, rel_path)
    a = cv2.imread(src)
    if a is None:
        # cv2.imread returns None (no exception) for missing/unreadable files.
        print("skip unreadable image:", src)
        continue
    dst = val_path + os.path.split(rel_path)[-1]
    if not cv2.imwrite(dst, a):
        raise OSError("failed to write " + dst)
    i += 1
print("一共验证集val多少图片:", i)
parse = argparse.ArgumentParser()
parse.add_argument('--config', dest='config', type=str, default='./configs/bisenetv2_city.py',)
parse.add_argument('--weight-path', type=str, default='./waterv3_model_final.pth',)
parse.add_argument('--img_path', dest='img_path', type=str, default=val_path,)
# Optional output folder for predicted masks; the default is the previously hard-coded folder.
parse.add_argument('--save_path', dest='save_path', type=str, default='G:\\addwater0906\\val_mask\\',)
args = parse.parse_args()
cfg = set_cfg_from_file(args.config)
# Random palette; not used further in this script.
palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred')
# NOTE: torch.load unpickles the checkpoint - only load weight files you trust.
load_result = net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False)
# strict=False silently skips mismatched keys; report them so a wrong checkpoint is noticed.
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys:', load_result.missing_keys)
    print('unexpected keys:', load_result.unexpected_keys)
net.eval()
to_tensor = T.ToTensor(
    mean=(0.3257, 0.3690, 0.3223),
    std=(0.2112, 0.2148, 0.2115),
)
dir_path = args.img_path
save_dir = args.save_path
os.makedirs(save_dir, exist_ok=True)  # cv2.imwrite returns False instead of raising if missing
for file_name in tqdm(os.listdir(dir_path)):
    # Build the path from --img_path (previously val_path was used, ignoring the argument).
    path = os.path.join(dir_path, file_name)
    im = cv2.imread(path)
    if im is None:
        print('skip unreadable file:', path)
        continue
    im = im[:, :, ::-1]  # BGR (OpenCV) -> RGB
    im = to_tensor(dict(im=im, lb=None))['im'].unsqueeze(0)
    out = net(im).squeeze().detach().cpu().numpy()
    # Store class ids as uint8 so OpenCV can encode them and they read back with
    # cv2.imread(..., 0); assumes cfg.n_cats <= 256.
    # NOTE(review): the mask keeps the input's file name/extension - if that is .jpg,
    # lossy compression will corrupt label values; a .png name would be safer.
    dst = os.path.join(save_dir, file_name)
    if not cv2.imwrite(dst, out.astype(np.uint8)):
        raise OSError('failed to write ' + dst)
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import cv2
def create_pascal_label_colormap():
    """Build the 256-entry PASCAL VOC colour table.

    Bit 3*k + c of a label id (k = 0..7, c = channel 0..2) becomes bit 7 - k of
    that label's channel c, so low label bits map to the brightest colour bits.

    Returns:
        Integer array of shape (256, 3); row ``i`` is the colour of label ``i``.
    """
    labels = np.arange(256, dtype=int)
    bit_rank = np.arange(8)
    cmap = np.zeros((256, 3), dtype=int)
    for chan in range(3):
        # (256, 8) matrix: the k-th selected bit of every label for this channel.
        picked = (labels[:, None] >> (3 * bit_rank + chan)) & 1
        # Bits land on distinct positions, so summing equals OR-ing them together.
        cmap[:, chan] = (picked << (7 - bit_rank)).sum(axis=1)
    return cmap
def label_to_color_image(label):
    """Colour a 2-D label map with the PASCAL VOC colormap.

    Args:
        label: 2-D array-like of integer class ids.

    Returns:
        Integer array of shape ``label.shape + (3,)``; each id replaced by its colour.

    Raises:
        ValueError: if ``label`` is not 2-D, or contains ids outside ``[0, 255]``.
    """
    # asarray turns e.g. a failed cv2.imread (None) into a clear ValueError below
    # instead of an AttributeError on `.ndim`.
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if label.size:
        # Negative ids would silently wrap around via numpy negative indexing.
        if np.min(label) < 0:
            raise ValueError('label value must be non-negative.')
        if np.max(label) >= len(colormap):
            raise ValueError('label value too large.')
    return colormap[label]
def vis_segmentation(image, seg_map):
    """
    Visualise an input image together with its segmentation mask.

    Creates a new 1x4 matplotlib figure with these panels:
    1. the input image;
    2. the colour-coded mask;
    3. the mask overlaid at 50% opacity;
    4. a legend strip for the label ids present in ``seg_map``.

    The figure is left open on pyplot's current state; the caller saves and closes it.

    :param image: image to display (the caller in this file passes an RGB array)
    :param seg_map: 2-D array of integer class ids
    Relies on the module-level LABEL_NAMES and FULL_COLOR_MAP for the legend, so every
    id in seg_map must be a valid index into them.
    """
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.5)
    plt.axis('off')
    plt.title('segmentation overlay')

    # Legend: one colour swatch per label id that actually occurs in the mask.
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    # Pass a real bool: matplotlib versions without 'on'/'off' string support treat
    # plt.grid('off') as a truthy value and turn the grid ON over the legend.
    ax.grid(False)
# Class names indexed by label id: 0 -> 'background', 1 -> 'water'.
LABEL_NAMES = np.array(['background', 'water'])
# Column vector [[0], [1], ...] holding every label id once (a 2-D "label map").
FULL_LABEL_MAP = np.arange(LABEL_NAMES.size).reshape(-1, 1)
# Colour of each label id, shape (num_labels, 1, 3); used for the legend in vis_segmentation.
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
# Input images, predicted masks (same file names) and output folder for the figures.
img_path = 'G:\\addwater0906\\val\\'
png_path = 'G:\\addwater0906\\val_mask\\'
save_dir = "G:\\addwater0906\\pltsave\\"
# plt.savefig raises FileNotFoundError when the target folder does not exist.
os.makedirs(save_dir, exist_ok=True)
i = 0
for img_name in os.listdir(img_path):
    print(img_name)
    imgfile = img_path + img_name
    pngfile = png_path + img_name
    new_path = save_dir + img_name
    image = cv2.imread(imgfile, 1)
    seg_map = cv2.imread(pngfile, 0)  # single-channel label ids
    if image is None or seg_map is None:
        # cv2.imread returns None for unreadable/missing files instead of raising.
        print('skip (unreadable image or missing mask):', img_name)
        continue
    image = image[:, :, ::-1]  # BGR (OpenCV) -> RGB (matplotlib)
    vis_segmentation(image, seg_map)
    plt.savefig(new_path)
    plt.close()
    i += 1
print('可视化总数量:', i)
print('Done.')
3. 将原始标注 mask 与推理得到的 mask 进行像素级比对，计算相似度
import os
from PIL import Image
import time
from tqdm import tqdm
def pixel_equal(image1, image2, x, y, threshold=1):
    """
    Decide whether two images have the same pixel at (x, y).

    Pixels match when every channel differs by less than ``threshold``; with the
    default threshold=1 and integer pixel values this means exact equality.

    :param image1: first image (anything with a PIL-style ``load()`` pixel accessor)
    :param image2: second image, same kind
    :param x: x coordinate (column)
    :param y: y coordinate (row)
    :param threshold: strict upper bound on the per-channel absolute difference
    :return: True if the pixels match, otherwise False
    """
    piex1 = image1.load()[x, y]
    piex2 = image2.load()[x, y]
    # Single-band modes ("L", "P", "I", "F") give a scalar rather than a tuple;
    # the original indexed [0..2] and crashed on them.
    if not isinstance(piex1, tuple):
        piex1 = (piex1,)
    if not isinstance(piex2, tuple):
        piex2 = (piex2,)
    return all(abs(a - b) < threshold for a, b in zip(piex1, piex2))
def compare(image1, image2, threshold=1):
    """
    Compare two images pixel by pixel.

    A pixel counts as "same" when every channel differs by less than ``threshold``.
    With the default threshold=1 and integer pixels, that means exact equality.

    :param image1: first image (PIL image or anything numpy.asarray accepts)
    :param image2: second image with the same size and number of bands
    :param threshold: strict upper bound on the per-channel absolute difference
    :return: (same_rate, nosame_rate), fractions of matching / differing pixels
    :raises ValueError: if the images differ in shape or are empty
    """
    # Local import: this script's own header only imports os/PIL/time/tqdm.
    import numpy as np

    # Cast before subtracting so uint8 pixels cannot wrap around (e.g. 0 - 255).
    pixels1 = np.asarray(image1).astype(np.float64)
    pixels2 = np.asarray(image2).astype(np.float64)
    if pixels1.shape != pixels2.shape:
        # The per-pixel loop raised IndexError or silently compared only a sub-region.
        raise ValueError("images differ in size/bands: %s vs %s" % (pixels1.shape, pixels2.shape))
    if pixels1.size == 0:
        raise ValueError("cannot compare empty images")
    close = np.abs(pixels1 - pixels2) < threshold
    if close.ndim == 3:
        # (H, W, C): a pixel matches only if all of its channels match.
        close = close.all(axis=-1)
    all_num = close.size
    right_num = int(np.count_nonzero(close))
    false_num = all_num - right_num
    return right_num / all_num, false_num / all_num
if __name__ == "__main__":
    # Ground-truth masks and predicted masks; files are paired by identical name.
    img_ori_path = r"F:\1207garbage_classification\0505_garbage_optimize\val_mask_ori"
    img_demo_path = r"F:\1207garbage_classification\0505_garbage_optimize\val_mask_demo"
    sum_same_rate = 0
    sum_nosame_rate = 0
    n = 0  # pairs actually compared; the averages divide by this, not by the dir size
    for name in tqdm(os.listdir(img_ori_path)):
        print(name)
        image1_path = os.path.join(img_ori_path, name)
        image2_path = os.path.join(img_demo_path, name)
        if not (os.path.isfile(image1_path) and os.path.isfile(image2_path)):
            print("skip (no matching predicted mask):", name)
            continue
        # `with` releases the file handles; convert() returns an in-memory copy.
        with Image.open(image1_path) as im1, Image.open(image2_path) as im2:
            image1 = im1.convert("RGB")
            image2 = im2.convert("RGB")
        same_rate, nosame_rate = compare(image1, image2)
        print("same_rate:%.2f,nosame_rate:%.2f" % (same_rate, nosame_rate))
        sum_same_rate += same_rate
        sum_nosame_rate += nosame_rate
        n += 1
    if n == 0:
        print("no image pairs were compared")
    else:
        avg_same_rate = sum_same_rate / n
        avg_nosame_rate = sum_nosame_rate / n
        print("avg_same_rate:%.2f,avg_nosame_rate:%.2f" % (avg_same_rate, avg_nosame_rate))
对比结果如下图所示：