请参考:YOLOv5(ultralytics) 训练自己的数据集,VOC2007为例
Fire Dataset:
https://download.csdn.net/download/W1995S/20666141
将数据集下载到 yolov5/my_data 文件夹下并合并文件夹,整理成 VOC 目录结构(Annotations 存放 xml 标注,JPEGImages 存放图片,ImageSets/Main 存放划分文件)。
1、在 ImageSets/Main 文件夹下生成 train.txt、val.txt、test.txt 和 trainval.txt 四个文件(存放图片名字,不含扩展名):在 my_data 目录下创建 split_train_val.py 文件:
import os
import random
import argparse
import time
# Fraction of all annotated images that goes into trainval (the rest is test).
trainval_percent = 1.0
# Fraction of trainval that goes into train (the rest is val).
train_percent = 0.9


def split_names(names, trainval_percent=trainval_percent, train_percent=train_percent):
    """Randomly partition image names into trainval/train/val/test lists.

    Args:
        names: sequence of image names (no extension).
        trainval_percent: fraction of ``names`` assigned to trainval;
            everything else becomes test.
        train_percent: fraction of trainval assigned to train; the
            remainder of trainval becomes val.

    Returns:
        dict with keys ``'trainval'``, ``'train'``, ``'val'`` and ``'test'``;
        each value is a list of names in the same order as ``names``.
    """
    num = len(names)
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(range(num), tv)
    # Sets give O(1) membership tests in the loop below.
    train = set(random.sample(trainval, tr))
    trainval = set(trainval)

    splits = {'trainval': [], 'train': [], 'val': [], 'test': []}
    for i, name in enumerate(names):
        if i in trainval:
            splits['trainval'].append(name)
            splits['train' if i in train else 'val'].append(name)
        else:
            splits['test'].append(name)
    return splits


def main():
    """Read the annotation directory and write the four split files."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml_path', default='Annotations', type=str, help='input xml label path')
    parser.add_argument('--txt_path', default='ImageSets/Main', type=str, help='output txt label path')
    opt = parser.parse_args()

    # Only real annotation files count; sorting makes the listing order
    # independent of the filesystem.
    names = sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(opt.xml_path)
        if entry.lower().endswith('.xml')
    )

    os.makedirs(opt.txt_path, exist_ok=True)
    for split, members in split_names(names).items():
        with open(os.path.join(opt.txt_path, split + '.txt'), 'w') as f:
            f.writelines(name + '\n' for name in members)


if __name__ == '__main__':
    main()
2、创建yolo格式的label,my_data目录下,创建xml2yolo_label.py 文件,会产生labels文件夹:
import os
import time
import sys
from xml.etree import ElementTree
from xml.etree.ElementTree import Element, SubElement
from lxml import etree
import codecs
import cv2
from glob import glob
XML_EXT = '.xml'
ENCODE_METHOD = 'utf-8'


class PascalVocReader:
    """Read the object annotations of one Pascal VOC XML file.

    Parsing happens in the constructor and is best-effort: if the file
    cannot be read or is malformed, a warning is printed to stderr and
    ``shapes`` keeps whatever was parsed before the failure.

    Attributes:
        shapes: list of ``(label, points, filename, difficult)`` tuples,
            where ``points`` are the four box corners in the order
            top-left, top-right, bottom-right, bottom-left.
        filepath: path of the XML file given to the constructor.
        verified: True when the root element has ``verified="yes"``.
    """

    def __init__(self, filepath):
        self.shapes = []
        self.filepath = filepath
        self.verified = False
        try:
            self.parseXML()
        except (OSError, ValueError, TypeError, AttributeError,
                ElementTree.ParseError) as err:
            # Stay best-effort, but make the failure visible instead of
            # silently producing an empty annotation.
            print("Warning: failed to parse %s: %s" % (filepath, err), file=sys.stderr)

    def getShapes(self):
        """Return the list of parsed shapes."""
        return self.shapes

    def addShape(self, label, bndbox, filename, difficult):
        """Append one bounding box, read from a ``<bndbox>`` element.

        Coordinates written as floats (e.g. ``"12.0"``) are accepted and
        truncated to integers.
        """
        xmin, ymin, xmax, ymax = (
            int(float(bndbox.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        points = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
        self.shapes.append((label, points, filename, difficult))

    def parseXML(self):
        """Parse ``self.filepath`` and fill ``self.shapes``.

        Returns:
            True on success.

        Raises:
            ValueError: if the path does not end with ``.xml`` or a
                numeric field cannot be converted.
            OSError, xml.etree.ElementTree.ParseError: if the file cannot
                be read or is not well-formed XML.
        """
        if not self.filepath.endswith(XML_EXT):
            raise ValueError("Unsupport file format")
        parser = ElementTree.XMLParser(encoding=ENCODE_METHOD)
        xmltree = ElementTree.parse(self.filepath, parser=parser).getroot()
        # <path> is optional in VOC files; fall back to <filename>.
        image_ref = xmltree.findtext('path') or xmltree.findtext('filename')
        if xmltree.attrib.get('verified') == 'yes':
            self.verified = True
        for object_iter in xmltree.findall('object'):
            bndbox = object_iter.find("bndbox")
            label = object_iter.find('name').text
            difficult_text = object_iter.findtext('difficult')
            difficult = bool(int(difficult_text)) if difficult_text else False
            self.addShape(label, bndbox, image_ref, difficult)
        return True
# Convert every VOC XML file under ./Annotations into a YOLO label file
# under ./labels, using the matching image under ./JPEGImages for its size.
classes = {}
parentpath = os.getcwd()
addxmlpath = os.path.join(parentpath, 'Annotations')
addimgpath = os.path.join(parentpath, 'JPEGImages')
outputpath = os.path.join(parentpath, 'labels')
# Optional predefined class list, one or more whitespace-separated names.
classes_txt = os.path.join(parentpath, 'fire_classes.txt')
ext = '.jpg'

os.makedirs(outputpath, exist_ok=True)

if os.path.isfile(classes_txt):
    with open(classes_txt, "r") as f:
        class_list = f.read().strip().split()
        classes = {k: v for (v, k) in enumerate(class_list)}
# Continue numbering after any preloaded classes so that newly discovered
# class names do not reuse an existing index.
num_classes = len(classes)

xmlPaths = glob(os.path.join(addxmlpath, "*.xml"))
for xmlPath in xmlPaths:
    stem = os.path.splitext(os.path.basename(xmlPath))[0]
    shapes = PascalVocReader(xmlPath).getShapes()
    filename = os.path.join(addimgpath, stem + ext)

    height = width = None
    if shapes:
        # The image is only needed for its size; read it once per file.
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread returns None instead of raising on a bad path.
            print("Warning: cannot read image %s, skipping %s" % (filename, xmlPath),
                  file=sys.stderr)
            continue
        height, width = image.shape[:2]

    # An XML file without objects still gets an (empty) label file.
    with open(os.path.join(outputpath, stem + ".txt"), "w") as f:
        for shape in shapes:
            class_name = shape[0]
            box = shape[1]
            if class_name not in classes:
                classes[class_name] = num_classes
                num_classes += 1
            class_idx = classes[class_name]
            coord_min = box[0]  # top-left corner (xmin, ymin)
            coord_max = box[2]  # bottom-right corner (xmax, ymax)
            xcen = float((coord_min[0] + coord_max[0])) / 2 / width
            ycen = float((coord_min[1] + coord_max[1])) / 2 / height
            w = float((coord_max[0] - coord_min[0])) / width
            h = float((coord_max[1] - coord_min[1])) / height
            f.write("%d %.06f %.06f %.06f %.06f\n" % (class_idx, xcen, ycen, w, h))
            print(class_idx, xcen, ycen, w, h)

# Dump the class names; the dict preserves insertion order, so line i holds
# the class with index i.
with open(os.path.join(parentpath, "classes.txt"), "w") as f:
    for key in classes:
        f.write("%s\n" % key)
        print(key)
3、创建训练、测试图片数据路径, my_data目录下,创建firedata_path.py 文件,会产生3个txt文件:
import xml.etree.ElementTree as ET
import os
# Root directory of the dataset: the script is run from inside it.
abs_path = os.getcwd()
# Class names to keep; the position in this list is the YOLO class id.
classes = ['fire']
# Split names; each one has an id list at ImageSets/Main/<name>.txt.
sets = ['train', 'val', 'test']
def convert(size, box):
    """Turn a pixel box into a normalized (x_center, y_center, width, height).

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        Tuple ``(x, y, w, h)``: the box center shifted by one pixel and the
        box extent, each scaled by the reciprocal of the image size.
    """
    scale_x = 1. / (size[0])
    scale_y = 1. / (size[1])
    left, right, top, bottom = box[0], box[1], box[2], box[3]
    center_x = (left + right) / 2.0 - 1
    center_y = (top + bottom) / 2.0 - 1
    span_x = right - left
    span_y = bottom - top
    return center_x * scale_x, center_y * scale_y, span_x * scale_x, span_y * scale_y
def convert_annotation(image_id):
    """Write the YOLO label file for one image from its VOC XML annotation.

    Reads ``<abs_path>/Annotations/<image_id>.xml`` and writes
    ``<abs_path>/labels/<image_id>.txt`` with one line per kept object:
    ``<class_id> <x> <y> <w> <h>``. Objects whose class is not in
    ``classes`` or that are marked difficult are skipped.

    Args:
        image_id: image name without extension.
    """
    xml_path = abs_path + '/Annotations/%s.xml' % (image_id)
    label_path = abs_path + '/labels/%s.txt' % (image_id)

    # Parse inside a context manager so the file handle is always released.
    with open(xml_path, encoding='UTF-8') as in_file:
        root = ET.parse(in_file).getroot()

    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(label_path, 'w') as out_file:
        for obj in root.iter('object'):
            # <difficult> is optional; treat a missing or empty tag as 0.
            difficult = obj.findtext('difficult') or '0'
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            xmin = float(xmlbox.find('xmin').text)
            xmax = float(xmlbox.find('xmax').text)
            ymin = float(xmlbox.find('ymin').text)
            ymax = float(xmlbox.find('ymax').text)
            # Clamp boxes that run past the right or bottom image edge.
            if xmax > w:
                xmax = w
            if ymax > h:
                ymax = h
            bb = convert((w, h), (xmin, xmax, ymin, ymax))
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
# For every split, write <abs_path>/<split>.txt with one absolute image path
# per line and generate the YOLO label file of each listed image.
os.makedirs(abs_path + '/labels/', exist_ok=True)
for image_set in sets:
    with open(abs_path + '/ImageSets/Main/%s.txt' % (image_set)) as id_file:
        # splitlines() keeps the last id even without a trailing newline;
        # blank lines are ignored.
        image_ids = [line for line in id_file.read().splitlines() if line]
    print(image_ids)
    with open(abs_path + '/%s.txt' % (image_set), 'w') as list_file:
        for image_id in image_ids:
            list_file.write(abs_path + '/JPEGImages/%s.jpg\n' % (image_id))
            convert_annotation(image_id)
4、配置数据集文件:在数据集文件夹下新建一个 fire.yaml 文件(注意:下面的 train/val 路径以及训练命令中的 --data 参数使用的是 my_data/fire 目录,请按自己实际的数据集目录修改):
train: /home/cv/PycharmProjects/YOLOv5/yolov5_ultralytics_v5/my_data/fire/train.txt
val: /home/cv/PycharmProjects/YOLOv5/yolov5_ultralytics_v5/my_data/fire/val.txt
test:
nc: 1
names: [ 'fire' ]
训练
在train.py 进行修改,跑个100轮,batch_size看显存:
python train.py --batch 16 --epoch 100 --weights weights/yolov5s.pt --data my_data/fire/fire.yaml --cfg models/yolov5s.yaml
检测
视频:
python detect.py --source data/videos/fire.mp4 --weights runs/train/exp4/weights/best.pt
数据集小啊,好像只能检测红色火焰。
|