IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> mmdetection 训练自己的数据集[v2.14.0 (29/6/2021)] -> 正文阅读

[Python知识库]mmdetection 训练自己的数据集[v2.14.0 (29/6/2021)]

0 环境配置

本篇文章在ubuntu18.04中配置。

主要版本如下:
Python 3.7
Cuda 10.1
Torch 1.7.0
参考mmdetection 环境配置(v2.14.0 (29/6/2021))

1 准备好数据集及标注文件

本篇使用COCO格式的数据集进行训练。
官方需求的格式如下,train2017和val2017放的是对应的照片。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   │   ├── instances_train2017.json
│   │   │   ├── instances_val2017.json
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017

我的data目录里格式如下,且标注文件为xml格式,需要做一些修改。

├── data
│   ├── images
│   │   ├── 1.jpg
│   │   ├── 2.jpg
......
│   │   ├── 10000.jpg
│   ├── annotations
│   │   ├── 1.xml
│   │   ├── 2.xml
......
│   │   ├── 10000.xml

1.1 首先把图片和标注文件划分为训练集和验证集。
划分后的目录如下所示

├── data
│   ├── images
│   ├── annotations
‘‘‘下面几个文件夹是划分后出现的’’’
│   ├── train_images
│   ├── train_annotations
│   ├── test_images
│   ├── test_annotations

使用的python脚本如下,split_images_and_XmlOrTxt.py,训练集比例默认0.9
可以直接将其放到和images同目录下
命令行运行python plit_images_and_XmlOrTxt.py --all_data_pat "自己存放图片的文件夹,默认images" --all_label_path "自己存放标签的文件夹,默认images"
例如我的存放图片的文件夹名字为my_images, 存放标签文件夹为 my_labels, 标签文件后缀为xml, 则需要执行以下命令python plit_images_and_XmlOrTxt.py --all_data_pat my_images --all_label_path my_labels
当标签后缀不为xml时候,如为txt,则在上述命令后加 --label_suffix txt即可。

#split_images_and_XmlOrTxt.py`
import os
import shutil
import random
from tqdm import tqdm
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--label_suffix", default='xml' , help = "标签文件后缀  默认xml", type=str)
parser.add_argument("--all_data_path", default='images' , help="自己存放图片的文件夹,默认images", type=str)
parser.add_argument("--all_label_path", default='annotations' ,  help="自己存放标签的文件夹,默认images", type=str)
parser.add_argument("--train_percent", default=0.9 , help="训练集数据的比例,默认0.9", type=int)

args = parser.parse_args()
print('--------------')
print('开始处理')
'''
实现功能, 把images和对应的annotations里面的标签分割成训练集和测试集,
        并把分割后的训练集图片放入train_images、标签放入train_annotations
        测试集图片放入test_images、标签放入test_annotations
--mydata
----images
----annotations
'''
train_percent = args.train_percent #0-1
all_data_path = args.all_data_path
all_label_path = args.all_label_path
LabelSuffix = args.label_suffix   #标签文件后缀  默认xml
train_images_path = 'train_images'
train_labels_path = 'train_annotations'
test_images_path = 'test_images'
test_labels_path = 'test_annotations'

#判断一下这几个目录是否存在,不存在则创建
if  not os.path.exists(train_images_path):
    os.mkdir(train_images_path)
if  not os.path.exists(train_labels_path):
    os.mkdir(train_labels_path)
if  not os.path.exists(test_images_path):
    os.mkdir(test_images_path)
if  not os.path.exists(test_labels_path):
    os.mkdir(test_labels_path)

data_file = os.listdir(all_data_path)
for i in tqdm(data_file):
    shutil.copy(os.path.join(all_data_path, i), os.path.join(train_images_path, i))
    shutil.copy(os.path.join(all_label_path, i[:-3] + LabelSuffix), os.path.join(train_labels_path, i[:-3] + LabelSuffix))
test_file = random.sample(data_file, int((1 - train_percent) * len(data_file)))
for j in tqdm(test_file):
    shutil.move(os.path.join(train_images_path, j), os.path.join(test_images_path, j))
    shutil.move(os.path.join(train_labels_path, j[:-3] + LabelSuffix), os.path.join(test_labels_path, j[:-3] + LabelSuffix))

print('处理完成')
print('--------------')

1.2 把训练集和验证集各自的标注文件转换为coco格式。这里提供一个xml转coco的脚本,命令行运行,输入两个参数,第一个参数是存放标注的文件夹,第二个参数是生成的json文件名。
在运行前记得把文件头部的类别改为自己需要的。

PRE_DEFINE_CATEGORIES = {'liner': 1, 'sailboat': 2, 'warship': 3, 'canoe': 4,
                         'bulk carrier': 5, 'container ship': 6, 'fishing boat': 7} #ToDo

xml2coco.py如下

#xml2coco.py
# pip install lxml

import sys
import os
import json
import xml.etree.ElementTree as ET
from tqdm import tqdm


PRE_DEFINE_CATEGORIES = {'liner': 1, 'sailboat': 2, 'warship': 3, 'canoe': 4,
                         'bulk carrier': 5, 'container ship': 6, 'fishing boat': 7} #ToDo

# If necessary, pre-define category and its id
#  PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
                         #  "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
                         #  "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
                         #  "motorbike": 14, "person": 15, "pottedplant": 16,
                         #  "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
START_BOUNDING_BOX_ID = 0

def get(root, name):
    vars = root.findall(name)
    return vars


def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars


def get_filename_as_int(filename):
    try:
        filename = os.path.splitext(filename)[0]
        return int(filename)
    except:
        raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename))


def convert(xml_list, xml_dir, json_file):
    list_fp = open(xml_list, 'r')
    json_dict = {"images":[], "type": "instances", "annotations": [],
                 "categories": []}
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID    #每个标注框id唯一
    image_id = 0   #每个图片id唯一
    list_fp = tqdm(list_fp)         #tqdm使用
    for line in list_fp:

        line = line.strip()

        list_fp.set_description("Processing %s" % line)
        xml_f = os.path.join(xml_dir, line)
        tree = ET.parse(xml_f)
        root = tree.getroot()
        path = get(root, 'path')  
        if len(path) == 1:
            filename = os.path.basename(path[0].text)    #xml中path的文件名
        elif len(path) == 0:
            filename = get_and_check(root, 'filename', 1).text
        else:
            raise NotImplementedError('%d paths found in %s'%(len(path), line))
    
        filename = line[:-3] + 'jpg'  

        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        image = {'file_name': filename, 'height': height, 'width': width,
                 'id':image_id}

        json_dict['images'].append(image)
        ## Cruuently we do not support segmentation
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                new_id = len(categories)
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)

            xmin = int(get_and_check(bndbox, 'xmin', 1).text)
            ymin = int(get_and_check(bndbox, 'ymin', 1).text)
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            assert(xmax > xmin)
            assert(ymax > ymin)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            segment = [xmin, ymin, xmin, ymin + o_height, xmin + o_width, ymin + o_height,
                       xmin + o_width, ymin]
            ann = {'area': o_width*o_height, 'iscrowd': 0, 'image_id':
                   image_id, 'bbox':[xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': [segment]}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1
        image_id = image_id + 1

    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    json_fp = open(json_file, 'w')
    json_str = json.dumps(json_dict)
    json_fp.write(json_str)
    json_fp.close()
    list_fp.close()
    os.remove(sys.argv[1] + '/../xml_list.txt')

if __name__ == '__main__':
    if len(sys.argv) <= 1:
        print('2 auguments are need.')
        print('Usage: %s  XML_DIR OUTPU_JSON.json'%(sys.argv[0]))
        exit(1)
    res = os.listdir(sys.argv[1])
    res.sort()
    with open(sys.argv[1] + '/../xml_list.txt','a') as f:
        for i in range(len(res)):
            f.write(res[i])
            f.write('\n')
            
    convert(sys.argv[1] + '/../xml_list.txt', sys.argv[1], sys.argv[2])

1.3 然后修改文件名,按照如下所示放置数据文件。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   │   ├── instances_train2017.json
│   │   │   ├── instances_val2017.json
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017

至此,数据集的准备工作完成。


2 修改相关配置文件

这里以faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth模型为例,配置文件为mmdetection/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py
上面的步骤已经准备好了coco数据集,但是官方提供的代码中,class_name和class_num是需要修改的。

2.1 定义数据种类

2.1.1 需要修改的地方在mmdetection/mmdet/datasets/coco.py。把CLASSES的那个tuple改为自己数据集对应的种类tuple即可。

class CocoDataset(CustomDataset):
    CLASSES = ('liner', 'sailboat', 'warship', 'canoe', 'bulk carrier', 'container ship', 'fishing boat')

注意:如果只有一个类,要加上一个逗号,否则将会报错。

2.1.2mmdetection/mmdet/core/evaluation/class_names.py修改coco_classes数据集类别,这个关系到后面test的时候结果图中显示的类别名称。例如:
去对应目录寻找

def coco_classes():
    return [
        'liner', 'sailboat', 'warship', 'canoe', 'bulk carrier', 'container ship', 'fishing boat'
    ]

2.2 修改模型文件

(因为模型为faster_rcnn_r101_fpn_1x,所以配置对应的config文件)

找到mmdetection/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py文件,打开后发现

_base_ = './faster_rcnn_r50_fpn_1x_coco.py'

继续沿着这个路径去找文件,发现指向的以下文件

_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',   
    #指向的是model dict,修改其中的num_classes类别为自己的类别。
    
    '../_base_/datasets/coco_detection.py',
    #data dict中的workers_per_gpu=2设置为0,train_pipeline和test_pipeline中的img_scale根据自己的图片尺寸修改。
    
    '../_base_/schedules/schedule_1x.py',
    #optimizer dict中修改学习率lr。当gpu数量为8时,lr=0.02;当gpu数量为4时,lr=0.01;我只有一个gpu,所以设置lr=0.0025
    
     '../_base_/default_runtime.py'
]

(如果是尝试训练,以上修改就可以,如果想要优化自己的模型等等,可以参考这个博客的代码详解这里


3 开始训练

训练前在mmdetection的目录下新建work_dirs文件夹。

重要:若改动框架源代码后,一定要注意重新编译后再使用。类似这里修改了几个源代码文件后再使用train命令之前,先要编译,执行下面命令。

pip install -v -e .  # or "python setup.py develop"

最后,执行训练指令。

python tools/train.py configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py   --work-dir work_dirs 






最后:遇到的一些问题以及解决方法。

1、如果出现model num_class == 80 , num_class == 15的一个错误,很可能是你修改完类别后没有重新编译,此时重新执行一下pip install -v -e . # or "python setup.py develop" 应该就可以解决问题。

2、 问题:

class_name

rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) File 
"pycocotools/_mask.pyx", line 292, in pycocotools._mask.frPyObjects 
IndexError: list index out of range

解决方法:(我在使用mask_rcnn模型训练时候碰到,因为coco数据集中的segmentation代表边界点,一般是分割时候使用,所以在转格式时候就把segmentation置空了,结果就出现这个错误。后来在矩形框中取了四个定点,四条边的中点共8个点,然后就不报错了)。即instances_train2017.json文件中segmentation不能为空。

3、 问题:

File "/home/chen/anaconda3/envs/mmdet/lib/python3.7/site-packages/mmcv/image/geometric.py", line 517, in impad
    value=pad_val)
cv2.error: OpenCV(4.5.2) /tmp/pip-req-build-947ayiyu/opencv/modules/core/src/copy.cpp:1026: error: (-215:Assertion failed) top >= 0 && bottom >= 0 && left >= 0 && right >= 0 && _src.dims() <= 2 in function 'copyMakeBorder'

解决方法:正在寻找解决方法…

参考博客
1、mmdetection训练自己的coco数据

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-07-15 16:07:41  更:2021-07-15 16:08:22 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/25 14:39:26-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计