# xml2txt.py
"""
功能
1.根据train.txt和val.txt将voc数据集标注信息(.xml)转为yolo标注格式(.txt),生成dataset文件(train+val)
"""
import os
from tqdm import tqdm
from lxml import etree
import json
import shutil
from os.path import *
# Resolve all dataset paths relative to the project root (the parent of this
# file's directory).
dir_path = dirname(dirname(abspath(__file__)))
images_path = os.path.join(dir_path, "Datasets", "Images")
xml_path = os.path.join(dir_path, "Datasets", "Annotations")
train_txt_path = os.path.join(dir_path, "Datasets", "ImageSets", "train.txt")
val_txt_path = os.path.join(dir_path, "Datasets", "ImageSets", "val.txt")
test_txt_path = os.path.join(dir_path, "Datasets", "ImageSets", "test.txt")
label_json_path = os.path.join(dir_path, "pest_classes.json")
save_file_root = os.path.join(dir_path, "pest")

# Fail fast with a clear message if any required input is missing.
# Explicit raises instead of `assert`: asserts are stripped under `python -O`.
for _required_path, _error_msg in [
    (images_path, "images path not exist..."),
    (xml_path, "xml path not exist..."),
    (train_txt_path, "train txt file not exist..."),
    (val_txt_path, "val txt file not exist..."),
    (test_txt_path, "test txt file not exist..."),
    (label_json_path, "label_json_path does not exist..."),
]:
    if not os.path.exists(_required_path):
        raise FileNotFoundError(_error_msg)

# Create the output root if needed (no-op when it already exists).
os.makedirs(save_file_root, exist_ok=True)
def parse_xml_to_dict(xml):
    """
    Recursively parse an XML tree into a Python dictionary
    (modeled after TensorFlow's recursive_parse_xml_to_dict).

    Args:
        xml: xml tree obtained by parsing XML file contents using lxml.etree
    Returns:
        Python dictionary holding XML contents.
    """
    # Leaf element: map its tag straight to its text content.
    if not len(xml):
        return {xml.tag: xml.text}

    parsed = {}
    for node in xml:
        node_value = parse_xml_to_dict(node)[node.tag]
        if node.tag == 'object':
            # An annotation may contain several 'object' entries;
            # collect them all into a list under one key.
            parsed.setdefault(node.tag, []).append(node_value)
        else:
            parsed[node.tag] = node_value
    return {xml.tag: parsed}
def translate_info(file_names: list, save_root: str, class_dict: dict, train_val_test='train'):
    """
    Convert the matching xml annotations to YOLO-style txt label files
    (xyxy -> normalized xywh) and copy the images into the new dataset layout.

    :param file_names: all image stems for this split,
        e.g. ['20210819B000001', '20210819B000002', '20210819B000004', ...]
    :param save_root: root directory of the new dataset
    :param class_dict: label dict of the new dataset,
        e.g. {'powdery_mildew': 0, 'leaf_miner': 1, 'anthracnose': 2}
    :param train_val_test: which split this is: 'train', 'val' or 'test'
    :return: None
    """
    save_txt_path = os.path.join(save_root, train_val_test, "labels")
    os.makedirs(save_txt_path, exist_ok=True)
    save_images_path = os.path.join(save_root, train_val_test, "images")
    os.makedirs(save_images_path, exist_ok=True)
    for file in tqdm(file_names, desc="translate {} file...".format(train_val_test)):
        img_path = os.path.join(images_path, file + ".jpg")
        assert os.path.exists(img_path), "file:{} not exist...".format(img_path)
        xml_full_path = os.path.join(xml_path, file + ".xml")
        assert os.path.exists(xml_full_path), "file:{} not exist...".format(xml_full_path)

        # Read and parse the VOC annotation.
        with open(xml_full_path, encoding='UTF-8') as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = parse_xml_to_dict(xml)["annotation"]
        img_height = int(data["size"]["height"])
        img_width = int(data["size"]["width"])

        # Validate BEFORE opening the output file, so a failure does not
        # leave a truncated/empty label file behind.
        assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_full_path)

        lines = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            class_name = obj["name"].strip()
            class_index = class_dict[class_name]
            # xyxy corners -> center/size, normalized to [0, 1] by image size.
            xcenter = round((xmin + (xmax - xmin) / 2) / img_width, 6)
            ycenter = round((ymin + (ymax - ymin) / 2) / img_height, 6)
            w = round((xmax - xmin) / img_width, 6)
            h = round((ymax - ymin) / img_height, 6)
            lines.append(" ".join(str(i) for i in [class_index, xcenter, ycenter, w, h]))

        # One object per line: "class_index xcenter ycenter w h".
        with open(os.path.join(save_txt_path, file + ".txt"), "w") as f:
            f.write("\n".join(lines))

        shutil.copyfile(img_path, os.path.join(save_images_path, os.path.basename(img_path)))
if __name__ == "__main__":
json_file = open(label_json_path, 'r')
class_dict = json.load(json_file)
with open(train_txt_path, "r") as r:
train_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
translate_info(train_file_names, save_file_root, class_dict, "train")
with open(val_txt_path, "r") as r:
val_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
translate_info(val_file_names, save_file_root, class_dict, "val")
with open(test_txt_path, "r") as r:
test_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
translate_info(test_file_names, save_file_root, class_dict, "test")
# Generated dataset layout:
# |