前言
源码: YOLOv5源码. 导航: 【YOLOV5-5.0 源码讲解】整体项目文件导航.
\qquad
这个文件都是一些画图函数,是一个工具类。代码本身逻辑并不难,主要是一些包的函数可能大家没见过。这里我总结了一些画图包的一些常见的画图函数: 【Opencv、ImageDraw、Matplotlib、Pandas、Seaborn】一些常见的画图函数。如果在下面代码中碰到不太熟的画图函数,可以查一下我的笔记或者自己百度一下。
0、导入需要的包和基本配置
import glob
import math
import os
from copy import copy
from pathlib import Path
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg')
1、Colors
\qquad
这是一个颜色类,用于选择相应的颜色,比如画框线的颜色,字体颜色等等。
Colors类代码:
class Colors:
def __init__(self):
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb('#' + c) for c in hex]
self.n = len(self.palette)
def __call__(self, i, bgr=False):
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c
@staticmethod
def hex2rgb(h):
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
colors = Colors()
使用起来也是比较简单只要直接输入颜色序号即可:
2、plot_one_box、plot_one_box_PIL
\qquad
plot_one_box 和 plot_one_box_PIL 这两个函数都是用于在原图im上画一个bounding box,区别在于前者使用的是opencv画框,后者使用PIL画框。这两个函数的功能其实是重复的,其实我们用的比较多的是plot_one_box函数,plot_one_box_PIL几乎没用,了解下即可。
2.1、plot_one_box
\qquad
这个函数通常用在检测nms后(detect.py中)将最终的预测bounding box在原图中画出来,不过这个函数依次只能画一个框框。
plot_one_box函数代码:
def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
"""一般会用在detect.py中在nms之后变量每一个预测框,再将每个预测框画在原图上
使用opencv在原图im上画一个bounding box
:params x: 预测得到的bounding box [x1 y1 x2 y2]
:params im: 原图 要将bounding box画在这个图上 array
:params color: bounding box线的颜色
:params labels: 标签上的框框信息 类别 + score
:params line_thickness: bounding box的线宽
"""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label:
tf = max(tl - 1, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)
cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
这个函数一般会用在detect.py中在nms之后变量每一个预测框,再将每个预测框画在原图上如: 效果如下所示:
2.2、plot_one_box_PIL(没用到)
\qquad
这个函数是用PIL在原图中画一个框,作用和plot_one_box一样,而且我们一般都是用plot_one_box而不用这个函数,所以了解下即可。
plot_one_box_PIL函数代码:
def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
"""
使用PIL在原图im上画一个bounding box
:params box: 预测得到的bounding box [x1 y1 x2 y2]
:params im: 原图 要将bounding box画在这个图上 array
:params color: bounding box线的颜色
:params label: 标签上的bounding box框框信息 类别 + score
:params line_thickness: bounding box的线宽
"""
im = Image.fromarray(im)
draw = ImageDraw.Draw(im)
line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
draw.rectangle(box, width=line_thickness, outline=color)
if label:
font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
txt_width, txt_height = font.getsize(label)
draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
return np.asarray(im)
3、plot_wh_methods(没用到)
\qquad
这个函数主要是用来比较
y
a
=
e
x
y_a = e^x
ya?=ex 、
y
b
=
(
2
?
s
i
g
m
o
i
d
(
x
)
)
2
y_b = (2 * sigmoid(x))^2
yb?=(2?sigmoid(x))2 、
y
c
=
(
2
?
s
i
g
m
o
i
d
(
x
)
)
1.6
y_c = (2 * sigmoid(x))^{1.6}
yc?=(2?sigmoid(x))1.6 这三个函数图形的。其中
y
a
y_a
ya? 是普通的yolo method,
y
b
y_b
yb? 和
y
c
y_c
yc?是作者提出的powe method方法。在 https://github.com/ultralytics/yolov3/issues/168.中,作者由讨论过这个issue。作者在实验中发现使用原来的yolo method损失计算有时候会突然迅速走向无限None值, 而power method方式计算wh损失下降会比较平稳。最后实验证明
y
b
y_b
yb? 是最好的wh损失计算方式, yolov5-5.0的wh损失计算代码用的就是
y
b
y_b
yb? 计算方式 如:
yolo.py: loss.py: plot_wh_methods函数代码:
def plot_wh_methods():
"""没用到
比较ya=e^x、yb=(2 * sigmoid(x))^2 以及 yc=(2 * sigmoid(x))^1.6 三个图形
wh损失计算的方式ya、yb、yc三种 ya: yolo method yb/yc: power method
实验发现使用原来的yolo method损失计算有时候会突然迅速走向无限None值, 而power method方式计算wh损失下降会比较平稳
最后实验证明yb是最好的wh损失计算方式, yolov5-5.0的wh损失计算代码用的就是yb计算方式
Compares the two methods for width-height anchor multiplication
https://github.com/ultralytics/yolov3/issues/168
"""
x = np.arange(-4.0, 4.0, .1)
ya = np.exp(x)
yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
fig = plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(x, ya, '.-', label='YOLOv3')
plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
plt.xlim(left=-4, right=4)
plt.ylim(bottom=0, top=6)
plt.xlabel('input')
plt.ylabel('output')
plt.grid()
plt.legend()
fig.savefig('comparison.png', dpi=200)
\qquad
其实这个函数倒不是特别重要,只是可视化一下这三个函数,看看他们的区别,在代码中也没调用过这个函数。但是了解这种新型 wh 损失计算的方式(Power Method)还是很有必要的。
4、output_to_target、plot_images
\qquad
这两个函数其实也是对检测到的目标格式进行处理(output_to_target)然后再将其画框显示在原图上(plot_images)。不过这两个函数是用在test.py中的,针对的也不再是一张图片一个框,而是整个batch中的所有框。而且plot_images会将整个batch的图片都画在一张大图mosaic中,画不下的就删除。这些都是plot_images函数和plot_one_box的区别。
4.1、output_to_target
\qquad
这个函数是用于将经过nms后的output [num_obj,x1y1x2y2+conf+cls] -> [num_obj,batch_id+class+xywh+conf]。并不涉及画图操作,而是转化predict的格式,通常放在画图操作plot_images之前。
output_to_target函数代码:
def output_to_target(output):
"""用在test.py中进行绘制前3个batch的预测框predictions 因为只有predictions需要修改格式 target是不需要修改格式的
将经过nms后的output [num_obj,x1y1x2y2+conf+cls] -> [num_obj, batch_id+class+x+y+w+h+conf] 转变格式
以便在plot_images中进行绘图 + 显示label
Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
:params output: list{tensor(8)}分别对应着当前batch的8(batch_size)张图片做完nms后的结果
list中每个tensor[n, 6] n表示当前图片检测到的目标个数 6=x1y1x2y2+conf+cls
:return np.array(targets): [num_targets, batch_id+class+xywh+conf] 其中num_targets为当前batch中所有检测到目标框的个数
"""
targets = []
for i, o in enumerate(output):
for *box, conf, cls in o.cpu().numpy():
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
return np.array(targets)
4.1、plot_images
\qquad
这个函数是用来绘制一个batch的所有图片的框框(真实框或预测框)。使用在test.py中,且在output_to_target函数之后。而且这个函数是将一个batch的图片都放在一个大图mosaic上面,放不下删除。
plot_images函数代码:
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
"""用在test.py中进行绘制前3个batch的ground truth和预测框predictions(两个图) 一起保存 或者train.py中
将整个batch的labels都画在这个batch的images上
Plot image grid with labels
:params images: 当前batch的所有图片 Tensor [batch_size, 3, h, w] 且图片都是归一化后的
:params targets: 直接来自target: Tensor[num_target, img_index+class+xywh] [num_target, 6]
来自output_to_target: Tensor[num_pred, batch_id+class+xywh+conf] [num_pred, 7]
:params paths: tuple 当前batch中所有图片的地址
如: '..\\datasets\\coco128\\images\\train2017\\000000000315.jpg'
:params fname: 最终保存的文件路径 + 名字 runs\train\exp8\train_batch2.jpg
:params names: 传入的类名 从class index可以相应的key值 但是默认是None 只显示class index不显示类名
:params max_size: 图片的最大尺寸640 如果images有图片的大小(w/h)大于640则需要resize 如果都是小于640则不需要resize
:params max_subplots: 最大子图个数 16
:params mosaic: 一张大图 最多可以显示max_subplots张图片 将总多的图片(包括各自的label框框)一起贴在一起显示
mosaic每张图片的左上方还会显示当前图片的名字 最好以fname为名保存起来
"""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(targets, torch.Tensor):
targets = targets.cpu().numpy()
if np.max(images[0]) <= 1:
images *= 255
tl = 3
tf = max(tl - 1, 1)
bs, _, h, w = images.shape
bs = min(bs, max_subplots)
ns = np.ceil(bs ** 0.5)
scale_factor = max_size / max(h, w)
if scale_factor < 1:
h = math.ceil(scale_factor * h)
w = math.ceil(scale_factor * w)
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
for i, img in enumerate(images):
if i == max_subplots:
break
block_x = int(w * (i // ns))
block_y = int(h * (i % ns))
img = img.transpose(1, 2, 0)
if scale_factor < 1:
img = cv2.resize(img, (w, h))
mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
if len(targets) > 0:
image_targets = targets[targets[:, 0] == i]
boxes = xywh2xyxy(image_targets[:, 2:6]).T
classes = image_targets[:, 1].astype('int')
labels = image_targets.shape[1] == 6
conf = None if labels else image_targets[:, 6]
if boxes.shape[1]:
if boxes.max() <= 1.01:
boxes[[0, 2]] *= w
boxes[[1, 3]] *= h
elif scale_factor < 1:
boxes *= scale_factor
boxes[[0, 2]] += block_x
boxes[[1, 3]] += block_y
for j, box in enumerate(boxes.T):
cls = int(classes[j])
color = colors(cls)
cls = names[cls] if names else cls
if labels or conf[j] > 0.25:
label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
if paths:
label = Path(paths[i]).name[:40]
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0,
tl / 3, [220, 220, 220], thickness=tf, lineType=cv2.LINE_AA)
cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
if fname:
r = min(1280. / max(h, w) / ns, 1.0)
mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
Image.fromarray(mosaic).save(fname)
return mosaic
这两个函数都是用在test.py函数中的:
用在train.py:
执行效果test.py(target): 执行效果test.py(predict):
5、plot_lr_scheduler
\qquad
这个函数是用来画出在训练过程中每个epoch的学习率变化情况。
plot_lr_scheduler函数代码:
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
"""用在train.py中学习率设置后可视化一下
Plot LR simulating training for full epochs
:params optimizer: 优化器
:params scheduler: 策略调整器
:params epochs: x
:params save_dir: lr图片 保存地址
"""
optimizer, scheduler = copy(optimizer), copy(scheduler)
y = []
for _ in range(epochs):
scheduler.step()
y.append(optimizer.param_groups[0]['lr'])
plt.plot(y, '.-', label='LR')
plt.xlabel('epoch')
plt.ylabel('LR')
plt.grid()
plt.xlim(0, epochs)
plt.ylim(0)
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
plt.close()
通常用于train.py中学习率设置后可视化一下: 执行效果:
6、hist2d、plot_test_txt、plot_targets_txt
6.1、hist2d
\qquad
这个函数是使用numpy工具画出2d直方图。不过好像用的不多,大多数都是调用工具包封装好的2d直方图方法,所以这个包其实只在plot_evolution函数和plot_test_txt函数中使用。
hist2d函数代码:
def hist2d(x, y, n=100):
"""用在plot_evolution和plot_test_txt
使用numpy画出2d直方图
2d histogram used in labels.png and evolve.png
"""
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
return np.log(hist[xidx, yidx])
6.2、plot_test_txt
\qquad
这个函数是利用test.py中生成的test.txt文件(或者其他的*.txt文件),画出其xy直方图和xy双直方图。其实这个plot_test_txt这个函数作者并没有使用它,但是我们确实可以自己写一个脚本来执行这个函数,观察一下预测到的所有图像的目标情况(wh分布)。
plot_test_txt函数代码:
def plot_test_txt(test_dir='test.txt'):
"""可以自己写个脚本执行test.txt文件
利用test.txt xyxy画出其直方图和双直方图
Plot test.txt histograms
:params test_dir: test.py中生成的一些 save_dir/labels中的txt文件
"""
x = np.loadtxt(test_dir, dtype=np.float32)
box = xyxy2xywh(x[:, 2:6])
cx, cy = box[:, 0], box[:, 1]
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
ax.set_aspect('equal')
plt.savefig('hist2d.png', dpi=300)
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
ax[0].hist(cx, bins=600)
ax[1].hist(cy, bins=600)
plt.savefig('hist1d.png', dpi=200)
自己写个脚本使用: hist1d.png效果(横坐标分别是w和h,纵坐标是个数): hist2d.png效果(横坐标是x纵坐标是y):
6.3、plot_targets_txt(没用到)
\qquad
这个函数是利用targets.txt(真实框的xywh)画出其直方图。但是并没有使用这个函数,而且细心的可以发现这个函数和之后的plot_labels函数是重复的。所以这个函数就随便看看吧。
plot_targets_txt函数代码:
def plot_targets_txt():
"""没用到 和plot_labels作用重复
利用targets.txt xywh画出其直方图
Plot targets.txt histograms
"""
x = np.loadtxt('targets.txt', dtype=np.float32).T
s = ['x targets', 'y targets', 'width targets', 'height targets']
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
ax = ax.ravel()
for i in range(4):
ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
ax[i].legend()
ax[i].set_title(s[i])
plt.savefig('targets.jpg', dpi=200)
7、plot_labels
\qquad
这个函数是根据从datasets中取到的labels,分析其类别分布,画出labels相关直方图信息。最终会生成labels_correlogram.jpg和labels.jpg两张图片。labels_correlogram.jpg包含所有标签的 x,y,w,h的多变量联合分布直方图:查看两个或两个以上变量之间两两相互关系的可视化形式(里面包含x、y、w、h两两之间的分布直方图)。而labels.jpg包含了:ax[0]画出classes的各个类的分布直方图,ax[1]画出所有的真实框;ax[2]画出xy直方图;ax[3]画出wh直方图。
plot_labels函数代码:
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
"""通常用在train.py中 加载数据datasets和labels后 对labels进行可视化 分析labels信息
plot dataset labels 生成labels_correlogram.jpg和labels.jpg 画出数据集的labels相关直方图信息
:params labels: 数据集的全部真实框标签 (num_targets, class+xywh) (929, 5)
:params names: 数据集的所有类别名
:params save_dir: runs\train\exp21
:params loggers: 日志对象
"""
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose()
nc = int(c.max() + 1)
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
matplotlib.use('svg')
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
labels[:, 1:3] = 0.5
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
for cls, *box in labels[:1000]:
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))
ax[1].imshow(img)
ax[1].axis('off')
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.jpg', dpi=200)
matplotlib.use('Agg')
plt.close()
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
这个函数一般会用在train.py的载入数据datasets和labels后,统计分析labels相关分布信息: labels_correlogram.jpg执行效果:
labels.jpg执行效果:
8、plot_evolution
\qquad
这个函数用于超参进化的最后阶段,负责画出每个超参进化的结果。最终会生成evolve.png,里面是每个超参的进化情况以及相对应的mAP的散点图,每5个超参一行,并且会用 ‘+’ 标注好最佳mAP对应的超参值。且每个散点图会输出最佳mAP对应的每个超参的最佳超参。而且plot_evolution也调用了上面的hist2d函数。
plot_evolution函数代码:
def plot_evolution(yaml_file='data/hyp.finetune.yaml', save_dir=Path('')):
"""用在train.py的超参进化算法后,输出参超进化的结果
超参进化在每一轮都会产生一系列的进化后的超参(存在yaml_file) 以及每一轮都会算出当前轮次的7个指标(evolve.txt)
这个函数要做的就是把每个超参在所有轮次变化的值和maps以散点图的形式显示出来,并标出最大的map对应的超参值 一个超参一个散点图
:params yaml_file: 'runs/train/evolve/hyp_evolved.yaml'
"""
with open(yaml_file) as f:
hyp = yaml.safe_load(f)
x = np.loadtxt('evolve.txt', ndmin=2)
f = fitness(x)
plt.figure(figsize=(10, 12), tight_layout=True)
matplotlib.rc('font', **{'size': 8})
for i, (k, v) in enumerate(hyp.items()):
y = x[:, i + 7]
mu = y[f.argmax()]
plt.subplot(6, 5, i + 1)
plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
plt.plot(mu, f.max(), 'k+', markersize=15)
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
if i % 5 != 0:
plt.yticks([])
print('%15s: %.3g' % (k, mu))
plt.savefig(save_dir / 'evolve.png', dpi=200)
print('\nPlot saved as evolve.png')
这个函数通常会用在train.py的超参进化算法后,输出参超进化的结果: 函数执行效果evolve.png:
9、plot_results、plot_results_overlay、butter_lowpass_filtfilt
\qquad
这三个函数都是用来对result.txt中的各项指标进行可视化的,但是plot_results是将一个指标画在折线图上(共10个折线图),而 plot_results_overlay要做的是将原先的10个显示的指标,两个两个进行对比画在同一个折线图上(共5个折线图)。最后的butter_lowpass_filtfilt函数
9.1、plot_results
\qquad
这个函数是将训练后的结果 results.txt 中相关的训练指标画出来。
\qquad
result.txt中一行的元素分别有:当前epoch/总epochs-1 、当前的显存容量mem、box回归损失、obj置信度损失、cls分类损失、训练总损失、真实目标数量num_target、图片尺寸img_shape、Precision、Recall、map@0.5、map@0.5:0.95、测试box回归损失、测试obj置信度损、测试cls分类损失。
\qquad
在result.txt中画出的指标有:训练回归损失Box、训练置信度损失Objectness、训练分类损失Classification、Precision、Recall、验证回归损失 val Box、验证置信度损失val Objectness、验证分类损失val Classification、mAP@0.5、mAP@0.5:0.95。
plot_results函数代码:
def plot_results(start=0, stop=0, bucket='', id=(), save_dir=''):
"""'用在训练结束, 对训练结果进行可视化
画出训练完的 results.txt Plot training 'results*.txt' 最终生成results.png
:params start: 读取数据的开始epoch 因为result.txt的数据是一个epoch一行的
:params stop: 读取数据的结束epoch
:params bucket: 是否需要从googleapis中下载results*.txt文件
:params id: 需要从googleapis中下载的results + id.txt 默认为空
:params save_dir: 'runs\train\exp22'
"""
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
ax = ax.ravel()
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
if bucket:
files = ['results%g.txt' % x for x in id]
c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
os.system(c)
else:
files = list(Path(save_dir).glob('results*.txt'))
assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
for fi, f in enumerate(files):
try:
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1]
x = range(start, min(stop, n) if stop else n)
for i in range(10):
y = results[i, x]
if i in [0, 1, 2, 5, 6, 7]:
y[y == 0] = np.nan
ax[i].plot(x, y, marker='.', linewidth=2, markersize=8)
ax[i].set_title(s[i])
except Exception as e:
print('Warning: Plotting error for %s; %s' % (f, e))
fig.savefig(Path(save_dir) / 'results1.png', dpi=200)
这个函数会用在train.py训练结束后对训练结果进行可视化: 执行结果result1.png:
9.2、plot_results_overlay
\qquad
这个函数还是将result.txt文件中的各项指标进行可视化,不过将原先的10个折线图减为5个折线图, train和val两两相对比。 plot_results_overlay函数代码:
def plot_results_overlay(start=0, stop=0):
"""可以用在train.py或者自写一个文件
画出训练完的 results.txt Plot training 'results*.txt' 而且将原先的10个折线图缩减为5个折线图, train和val相对比
Plot training 'results*.txt', overlaying train and val losses
"""
s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']
t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']
for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1]
x = range(start, min(stop, n) if stop else n)
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
ax = ax.ravel()
for i in range(5):
for j in [i, i + 5]:
y = results[j, x]
ax[i].plot(x, y, marker='.', label=s[j])
ax[i].set_title(t[i])
ax[i].legend()
ax[i].set_ylabel(f) if i == 0 else None
fig.savefig(f.replace('.txt', '.png'), dpi=200)
这个函数可以放在plot_results下面,也可以自己写一个: 执行结果result.py:
9.3、butter_lowpass_filtfilt
\qquad
这个函数是为了防止在训练时有些指标非常的抖动,导致画出来很难看,比如下面这种情况: 红色部分真实值非常抖动,画出来很难看,那我们就对它进行一个平滑处理,取它的一个近似值。
butter_lowpass_filtfilt函数代码:
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
"""
当data值抖动太大, 就取data的平滑曲线
"""
from scipy.signal import butter, filtfilt
def butter_lowpass(cutoff, fs, order):
nyq = 0.5 * fs
normal_cutoff = cutoff / nyq
return butter(order, normal_cutoff, btype='low', analog=False)
b, a = butter_lowpass(cutoff, fs, order=order)
return filtfilt(b, a, data)
\qquad
这部分的代码是用在上面plot_results_overlay函数里面的,不过它是注释掉的,如果发现自己训练的结果发生上面的那种抖动情况,大家可以打开注释,或者任意调用这个函数达到一种平滑效果。这个函数代码我是每有看的,感兴趣可以自己读读,就几行应该不难。
10、feature_visualization
\qquad
这个函数是用来可视化feature map的,而且可以实现可视化网络中任意一层的feature map。
函数代码:
def feature_visualization(x, module_type, stage, n=64):
"""用在yolo.py的Model类中的forward_once函数中 自行选择任意层进行可视化该层feature map
可视化feature map(模型任意层都可以用)
:params x: Features map [bs, channels, height, width]
:params module_type: Module type
:params stage: Module stage within model
:params n: Maximum number of feature maps to plot
"""
batch, channels, height, width = x.shape
if height > 1 and width > 1:
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name)
save_dir.mkdir(parents=True, exist_ok=True)
plt.figure(tight_layout=True)
blocks = torch.chunk(x, channels, dim=1)
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature)
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
通常这个函数是把他放在yolo.py的Model类中的forward_once函数中: 自己可以在if中选择要查看的任意一层feature map。
原图: 执行效果:
11、plot_study_txt(没用到)、profile_idetection(没用到)
\qquad
剩下的这两个函数是没什么用的,plot_study_txt是test.py中opt.task == 'study’时评估yolov5系列和yolov3-spp各个模型在各个尺度下的指标并可视化,但是其实我们几乎用不到这里。另外一个函数profile_idetection完全没用到。所以这两个函数不看也可以。
def plot_study_txt(path='', x=None):
"""没用到
Plot study.txt generated by test.py
"""
plot2 = False
if plot2:
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
for f in sorted(Path(path).glob('study*.txt')):
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
x = np.arange(y.shape[1]) if x is None else np.array(x)
if plot2:
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
for i in range(7):
ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
ax[i].set_title(s[i])
j = y[3].argmax() + 1
ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
ax2.grid(alpha=0.2)
ax2.set_yticks(np.arange(20, 60, 5))
ax2.set_xlim(0, 57)
ax2.set_ylim(30, 55)
ax2.set_xlabel('GPU Speed (ms/img)')
ax2.set_ylabel('COCO AP val')
ax2.legend(loc='lower right')
plt.savefig(str(Path(path).name) + '.png', dpi=300)
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
"""没用到
Plot iDetection '*.txt' per-image logs
"""
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
files = list(Path(save_dir).glob('frames*.txt'))
for fi, f in enumerate(files):
try:
results = np.loadtxt(f, ndmin=2).T[:, 90:-30]
n = results.shape[1]
x = np.arange(start, min(stop, n) if stop else n)
results = results[:, x]
t = (results[0] - results[0].min())
results[0] = x
for i, a in enumerate(ax):
if i < len(results):
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
a.set_title(s[i])
a.set_xlabel('time (s)')
for side in ['top', 'right']:
a.spines[side].set_visible(False)
else:
a.remove()
except Exception as e:
print('Warning: Plotting error for %s; %s' % (f, e))
ax[1].legend()
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
总结
\qquad
这个文件的代码主要是一些画图用的工具函数,和我们目标检测的主要流程其实没用什么关系。比较重要的函数有:plot_one_box、output_to_target、plot_images、plot_labels、plot_evolution、plot_results、plot_results_overlay、feature_visualization等。
–2021.08.02 22:14
|