If there are any mistakes, please kindly point them out.
The code in this post comes from the blogger 太阳花的小绿豆; see the reference at the end for a detailed explanation. The code is only posted here as a personal note.
PS: the reference material explains everything in great detail.
1. k-means cluster centers
import numpy as np
from matplotlib import pyplot as plt
np.random.seed(0)
colors = np.array(['blue', 'black'])
def plot_clusters(data, cls, clusters, title=""):
if cls is None:
c = [colors[0]] * data.shape[0]
else:
c = colors[cls].tolist()
plt.scatter(data[:, 0], data[:, 1], c=c)
for i, clus in enumerate(clusters):
plt.scatter(clus[0], clus[1], c='gold', marker='*', s=150)
plt.title(title)
plt.show()
plt.close()
def distances(data, clusters):
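    """Squared Euclidean distance from every point in data (n, 2) to every cluster center (k, 2), returned as an (n, k) matrix."""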
xy1 = data[:, None]
xy2 = clusters[None]
d = np.sum(np.power(xy2 - xy1, 2), axis=-1)
return d
def k_means(data, k, dist=np.mean):
"""
    k-means method
    Args:
        data: the data to be clustered
        k: number of clusters (how many groups)
        dist: method used to update the cluster centers
"""
data_number = data.shape[0]
last_nearest = np.zeros((data_number,))
clusters = data[np.random.choice(data_number, k, replace=False)]
print(f"random cluster: \n {clusters}")
plot_clusters(data, None, clusters, "random clusters")
step = 0
while True:
d = distances(data, clusters)
current_nearest = np.argmin(d, axis=1)
plot_clusters(data, current_nearest, clusters, f"step {step}")
if (last_nearest == current_nearest).all():
break
for cluster in range(k):
clusters[cluster] = dist(data[current_nearest == cluster], axis=0)
last_nearest = current_nearest
step += 1
return clusters
def main():
x1, y1 = [np.random.normal(loc=1., size=150) for _ in range(2)]
x2, y2 = [np.random.normal(loc=5., size=150) for _ in range(2)]
x = np.concatenate([x1, x2])
y = np.concatenate([y1, y2])
plt.scatter(x, y, c='blue')
plt.title("initial data")
plt.show()
plt.close()
clusters = k_means(np.concatenate([x[:, None], y[:, None]], axis=-1), k=2)
print(f"k-means fluster: \n {clusters}")
if __name__ == '__main__':
main()
2. k-means clustering of anchors
import numpy as np
def wh_iou(wh1, wh2):
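    """IoU between sets of (w, h) pairs: wh1 (n, 2) vs wh2 (k, 2) -> (n, k) IoU matrix (boxes are assumed to share a common corner)."""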
wh1 = wh1[:, None]
wh2 = wh2[None]
inter = np.minimum(wh1, wh2).prod(2)
return inter / (wh1.prod(2) + wh2.prod(2) - inter)
def k_means(boxes, k, dist=np.median):
"""
yolo k-means methods
refer: https://github.com/qqwweee/keras-yolo3/blob/master/kmeans.py
Args:
        boxes: the bounding boxes (w, h) to be clustered
        k: number of clusters (how many groups)
        dist: method used to update the cluster centers (median by default, which works slightly better than the mean)
"""
box_number = boxes.shape[0]
last_nearest = np.zeros((box_number,))
clusters = boxes[np.random.choice(box_number, k, replace=False)]
while True:
distances = 1 - wh_iou(boxes, clusters)
current_nearest = np.argmin(distances, axis=1)
if (last_nearest == current_nearest).all():
break
for cluster in range(k):
clusters[cluster] = dist(boxes[current_nearest == cluster], axis=0)
last_nearest = current_nearest
return clusters
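A quick usage sketch (not from the original post), assuming it is appended to the same file as the k_means and wh_iou code above (e.g. yolo_kmeans.py): cluster some randomly generated relative (w, h) boxes into 3 anchors.
import numpy as np

if __name__ == "__main__":
    np.random.seed(0)
    # fake relative (w, h) pairs standing in for real gt boxes
    boxes = np.random.uniform(0.02, 0.9, size=(200, 2))
    anchors = k_means(boxes, k=3)
    anchors = anchors[np.argsort(anchors.prod(1))]  # sort by area, small to large
    print("clustered anchors (relative w, h):\n", anchors)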
3. k-means + genetic algorithm for anchor clustering
First, read the dataset's ground truth and collect the sample (box) sizes; these samples are then clustered into 3 anchors for each of the 3 feature maps, so 9 anchors are output in total. Sort the anchors by area from small to large and assign them, in order, to the 3 prediction feature layers.
During clustering, a genetic algorithm is used to apply small random perturbations to the anchors, which improves diversity and generalization.
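The assignment step itself is straightforward; here is a minimal sketch (not from the original post) of splitting the 9 clustered anchors, sorted by area, into 3 groups of 3 for the prediction feature layers. The function name assign_anchors_to_layers and the P3/P4/P5 labels are only illustrative.
import numpy as np

def assign_anchors_to_layers(anchors, num_layers=3):
    """anchors: (9, 2) array of (w, h); returns num_layers groups of 3 anchors,
    from the smallest (high-resolution layer) to the largest (low-resolution layer)."""
    anchors = np.asarray(anchors, dtype=np.float32)
    anchors = anchors[np.argsort(anchors.prod(axis=1))]  # sort by area w * h
    return np.split(anchors, num_layers)                 # equal-sized groups

if __name__ == "__main__":
    k = np.random.randint(10, 300, size=(9, 2))          # stand-in for clustered anchors
    for name, a in zip(["P3 (small)", "P4 (medium)", "P5 (large)"],
                       assign_anchors_to_layers(k)):
        print(name, a.tolist())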
- Reference code for reading the gt boxes from the dataset, read_voc.py:
import os
from tqdm import tqdm
from lxml import etree
class VOCDataSet(object):
def __init__(self, voc_root, year="2012", txt_name: str = "train.txt"):
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.annotations_root = os.path.join(self.root, "Annotations")
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
with open(txt_path) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines() if len(line.strip()) > 0]
        assert len(self.xml_list) > 0, "did not find any information in '{}' file.".format(txt_path)
for xml_path in self.xml_list:
assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)
def __len__(self):
return len(self.xml_list)
def parse_xml_to_dict(self, xml):
"""
        Parse an xml file into a dict, following TensorFlow's recursive_parse_xml_to_dict
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree
Returns:
Python dictionary holding XML contents.
"""
if len(xml) == 0:
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
def get_info(self):
im_wh_list = []
boxes_wh_list = []
for xml_path in tqdm(self.xml_list, desc="read data info."):
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
im_height = int(data["size"]["height"])
im_width = int(data["size"]["width"])
wh = []
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
wh.append([(xmax - xmin) / im_width, (ymax - ymin) / im_height])
if len(wh) == 0:
continue
im_wh_list.append([im_width, im_height])
boxes_wh_list.append(wh)
return im_wh_list, boxes_wh_list
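- Reference code for k-means + genetic algorithm anchor clustering, main.py (this assumes the section 2 code above is saved as yolo_kmeans.py, matching the import below):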
import random
import numpy as np
from tqdm import tqdm
from scipy.cluster.vq import kmeans
from read_voc import VOCDataSet
from yolo_kmeans import k_means, wh_iou
def anchor_fitness(k: np.ndarray, wh: np.ndarray, thr: float):
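    """k: (n, 2) anchors; wh: (m, 2) gt box sizes; returns (fitness, best possible recall) under the ratio threshold thr."""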
r = wh[:, None] / k[None]
x = np.minimum(r, 1. / r).min(2)
best = x.max(1)
f = (best * (best > thr).astype(np.float32)).mean()
bpr = (best > thr).astype(np.float32).mean()
return f, bpr
def main(img_size=512, n=9, thr=0.25, gen=1000):
dataset = VOCDataSet(voc_root="E:\学习\机器学习\数据集\VOC2012", year="2012", txt_name="train.txt")
im_wh, boxes_wh = dataset.get_info()
im_wh = np.array(im_wh, dtype=np.float32)
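    # scale each image so its longer side equals img_size (keeping aspect ratio),
    # then convert the relative gt (w, h) to absolute pixels on the scaled images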
shapes = img_size * im_wh / im_wh.max(1, keepdims=True)
wh0 = np.concatenate([l * s for s, l in zip(shapes, boxes_wh)])
i = (wh0 < 3.0).any(1).sum()
if i:
print(f'WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
wh = wh0[(wh0 >= 2.0).any(1)]
k = k_means(wh, n)
k = k[np.argsort(k.prod(1))]
f, bpr = anchor_fitness(k, wh, thr)
print("kmeans: " + " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")
npr = np.random
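    # genetic evolution of the anchors: mp = mutation probability, s = mutation scale (sigma)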
f, sh, mp, s = anchor_fitness(k, wh, thr)[0], k.shape, 0.9, 0.1
pbar = tqdm(range(gen), desc=f'Evolving anchors with Genetic Algorithm:')
for _ in pbar:
v = np.ones(sh)
while (v == 1).all():
v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
kg = (k.copy() * v).clip(min=2.0)
fg, bpr = anchor_fitness(kg, wh, thr)
if fg > f:
f, k = fg, kg.copy()
pbar.desc = f'Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
    # sort the evolved anchors by area and re-evaluate on the best k
    # (otherwise bpr would come from the last mutated candidate, not the best one)
    k = k[np.argsort(k.prod(1))]
    f, bpr = anchor_fitness(k, wh, thr)
    print("genetic: " + " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
    print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")
if __name__ == "__main__":
main()
References:
- Using k-means to cluster anchors: https://blog.csdn.net/qq_37541097/article/details/119647026