IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于mmdetection 旋转目标检测(OBB detection)+DOTA数据集&自定义数据集 -> 正文阅读

[人工智能]基于mmdetection 旋转目标检测(OBB detection)+DOTA数据集&自定义数据集

这两周得益于组里的任务需求,肝了一个遥感类的飞机旋转框目标检测,在给定的4096*4096的大尺度分辨率图片上去识别检测飞机。

经过模型检测后输出结果图如下图所示:
img_output_example
可以看到最终的结果还是不错的,通过该任务的锻炼,自己对一般的目标检测工程上的问题可以说基本走了个遍,能够完成基本的目标检测、旋转框的目标检测任务等。在这里简单分享一下任务的心得。

核心思想

	一:基于mmdetection的目标检测框架
	二:DOTA数据集格式
	三:任务相关

一:基于mmdetection的目标检测框架

(1)mmdetection相关
做目标检测现在已经绕不开mmdetection了,该框架是一个基于Pytorch实现的深度学习目标检测工具箱,与MMCV进行搭配使用。目前许多SOTA的模型都在上面进行更改。
一些教程:
MMDetection中文文档—详解:https://zhuanlan.zhihu.com/p/101225733
数据处理过程:https://blog.csdn.net/u014453898/article/details/107701094
configs相关:https://zhuanlan.zhihu.com/p/102072353

由于任务是检测目标的旋转框,最终选择了s2anet模型。对于一般的目标检测任务(如coco,VOC等)则尝试使用了VarifocalNet。
s2anet:https://github.com/csuhan/s2anet
varifocalNet:https://github.com/hyz-xmaster/VarifocalNet

mmdetection使用的关键步骤在于定义config文件,框架会从config中定义好的字段中去加载使用、定义相应的函数、模型、数据加载、数据预处理、优化器、路径等。一般而言需要修改模型中的num_classes,其值为数据集中分类的类别个数(较老的mmdet的版本需要个数+1,即加一个背景类)、数据集的加载路径、work_dir等。

(2)自定义数据集
自定义的数据集类定义在mmdet/datasets中,一般而言是将数据集转换成COCO(or voc)格式,然后继承
已经写好的CocoDataset(CustomDataset)。将类中的CLASSES属性修改成自定义数据集中的类别。在自定义的类的上一行需要加入@DATASETS.register_module (mmdet版本不一样可能有所区别) 用来注册已经自定义好的类,同时需要在 datasets/__init__py中加入自定义的数据集类。如下图所示:
自定义数据集
同时可以自己重写evaluation函数满足自己的评估需求。

二:DOTA数据集格式

(1)DOTA数据集简介
DOTA数据集是一个比较著名的遥感类高分辨率数据集,包括v1.0,v1.5,v2.0三个版本的数据,一共30G左右。采用旋转框的标记方式,标记四个顶点八个坐标得到不规则四边形。具体实现是,首先标注出一个初始点,为(x1,y1),然后顺时针方向依次标注2、3、4共4个点。如下图ace所示。bcd是传统的水平标注方法,有大量的重叠区域。
DOTA数据集标注样例
标注文件的格式如下图所示:DOTA_label
其中(x1,y1)用于表示OBB的顶点起始位置,四个顶点按照顺时针进行排列。category表示目标种类,difficult表示实例的检测难度。

DOTA_devkit是官方给的配套的数据处理的配套文件,包括绘制目标边框的示例,剪裁数据集、合并检测结果、评估模型性能等。
DOTA_devkit官方github:https://github.com/CAPTAIN-WHU/DOTA_devkit

下面这个DOTA_devkit的整理(踩坑记录)一文里详细介绍了DOTA_devkit的各个py文件的作用、代码中的实际应用、剪裁、合并策略等,介绍的比较全面。
DOTA_devkit的整理(踩坑记录):https://zhuanlan.zhihu.com/p/355862906

(2)标签格式转换
由于任务所给的数据集同样是遥感图片,且是大分辨率图片,目标也是标注OBB,因此可以类比于DOTA数据集的操作。第一步是将任务的数据集从labelme的标注格式转换至DOTA标注格式,然后采用官方给的DOTA_devkit进行图片的预处理操作。
同时,DOTA_devit的dota_evaluation_task1.py中的voc_eval()、即数据集的评估函数中,还需要提供一个测试图片的name_list和储存剪裁前的图片注释文件夹label_txt,其中name_list需要自己写一个脚本生成,如下所示:

import os
dir="/home/dataset/airplaneDOTA/airplane/val/images"
img_name_list=[]
for root, dirs, files in os.walk(dir):
  for file in files:
   # print os.path.join(root,file)
   img_name=file.split(".")[0]
   img_name_list.append(img_name)

write_path="/home/dataset/airplaneDOTA/airplane/val/test_image_list.txt"
#写入文本
with open(write_path,"w") as f:
    for i in range(len(img_name_list)):
        f.write(img_name_list[i]) 
        f.write("\n")
print("end")

在s2anet中,则是在evaluation中提供。

evaluation = dict( gt_dir='/home/dataset/airplaneDOTA/airplane/val/labelTxt/',# change it to valset for offline validation
imagesetfile='/home/dataset/airplaneDOTA/airplane/val/test_image_list.txt')

(3)图片裁剪策略
由于训练成本的问题,难以直接将4096x4096的图片直接输入到网络中训练,因此需要将图片进行相应的调整。最直观的方法是在图片预处理中将图片直接resize成1024x1024(or更低)的大小,但是这种方法会使一些本就size较小的目标在训练时更小,从而导致模型的训练性能受损,因此该种方法在实际应用中不进行考虑,而是选择合适的裁剪策略。

图片剪裁策略是将4096x4096的图片裁剪成1024x1024的图片。单纯的将4096x4096的图片按比例裁剪成16张1024x1024的图片明显不是个好的方法,因为如果目标正好位于两张图片的交界位置,那么两个图片各有一半的目标会大大影响裁剪性能。一种好的策略是使得裁剪后的图片有部分重合的像素,这样能够很大概率保证待检测的目标能够在某张或者多张裁剪后的图片里完整,同时也能够起到数据增强的目的。在实际应用中,剪裁图片重合的面积越大,实际效果越好(本任务中每条边重合512个像素,50%)。

同时,为了帮助模型训练多尺度的目标,将裁剪后的图片缩放至0.5倍、1.5倍并进行存储,使得数据集中包含同一张图片的0.5、1、1.5倍三种尺度比例的图片。

然后这种策略会导致数据集扩充比较大,一张4096x4096能够裁剪出100多张图片,但是该种方式会使得模型性能提高很多。使用的是ReDet中实现的prepare_dota1_5_v2.py的代码来进行裁剪,github如下:

https://github.com/csuhan/ReDet/blob/master/DOTA_devkit/prepare_dota1_5_v2.py

三:任务相关

(1)自定义数据集格式转换
如前文所言,对于任务所给的自定义数据集,最好的方法是将其转换成现有的写好的数据集标签格式进行训练,这样就免去自定义dataloader的烦恼。如下是常见的目标检测的格式转换code:

目标检测常见数据格式转换:https://github.com/spytensor/prepare_detection_dataset
格式转换完成后,自定义类里只需要继承相应的类,并在mmdet中register一下即可。

(2)mmdetection的版本以及适配CUDA的问题
mmdetecion目前来说已经趋于稳定,但是之前每个大版本之间还是差的有点多的,比如0.几版本都不包含@DATASETS.register_module(),只能在执行setup.py的时候进行注册(无法动态加载模块)。同时还需要适配不同版本的mmcv(mmcv-full)。同时,由于mmdetecion框架安装的时候需要自拟脚本,对CUDA的版本、torch的版本还有一定的要求。笔者在跑s2anet的时候使用的是官方的10.1版本的cuda以及1.3版本的torch,但由于任务的docker要求,需要适配cuda11与torch1.7版本,在”升级的时候“需要修改一下mmdetection安装时的setup.py文件以及相应的torch版本的不同带来的问题,花了不少时间改了很多bug才适配完成。可以参考以下网址:
https://github.com/open-mmlab/mmdetection/issues/3363

https://github.com/pytorch/pytorch/issues/52669

(3)图片预测
s2anet中给出了图片inference的代码示例。给定一张图片, 返回经过检测后的画有bounding box的图片。将待检测图片存入至img_dir的路径中,在out_dir中给出预测的图片。其中图片的预处理方式则采用的是config中data.test中的方式。

import argparse
import os
import os.path as osp
import pdb
import random

import cv2
import mmcv
from mmcv import Config

from mmdet.apis import init_detector, inference_detector
from mmdet.core import rotated_box_to_poly_single
from mmdet.datasets import build_dataset

def show_result_rbox(img,
                     detections,
                     class_names,
                     scale=1.0,
                     threshold=0.2,
                     colormap=None,
                     show_label=False):
    assert isinstance(class_names, (tuple, list))
    if colormap:
        assert len(class_names) == len(colormap)
    img = mmcv.imread(img)
    color_white = (255, 255, 255)

    for j, name in enumerate(class_names):
        if colormap:
            color = colormap[j]
        else:
            color = (random.randint(0, 256), random.randint(0, 256), random.randint(0, 256))
        try:
            dets = detections[j]
        except:
            pdb.set_trace()
        # import ipdb;ipdb.set_trace()
        for det in dets:
            score = det[-1]
            det = rotated_box_to_poly_single(det[:-1])
            bbox = det[:8] * scale
            if score < threshold:
                continue
            bbox = list(map(int, bbox))
     #       print(bbox)
            #[2482, 2230, 2550, 2239, 2542, 2301, 2474, 2292]坐标
            for i in range(3):
                cv2.line(img, (bbox[i * 2], bbox[i * 2 + 1]), (bbox[(i + 1) * 2], bbox[(i + 1) * 2 + 1]), color=color,
                         thickness=2, lineType=cv2.LINE_AA)
            cv2.line(img, (bbox[6], bbox[7]), (bbox[0], bbox[1]), color=color, thickness=2, lineType=cv2.LINE_AA)
            if show_label:
                cv2.putText(img, '%s %.3f' % (class_names[j], score), (bbox[0], bbox[1] + 10),
                            color=color_white, fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5)
    return img


def save_det_result(config_file, out_dir, checkpoint_file=None, img_dir=None, colormap=None):
    cfg = Config.fromfile(config_file)
    data_test = cfg.data.test
    dataset = build_dataset(data_test)
    classnames = dataset.CLASSES
  #  print(classnames)
    # use checkpoint path in cfg
    if not checkpoint_file:
        checkpoint_file = osp.join(cfg.work_dir, 'latest.pth')
  
    # use testset in cfg
    if not img_dir:
        img_dir = data_test.img_prefix

    model = init_detector(config_file, checkpoint_file, device='cuda:0')

    img_list = os.listdir(img_dir)
    for img_name in img_list:
        img_path = osp.join(img_dir, img_name)
        img_out_path = osp.join(out_dir, img_name)
        result = inference_detector(model, img_path)
        img = show_result_rbox(img_path,
                               result,
                               classnames,
                               scale=1.0,
                               threshold=0.5,
                               colormap=colormap)
  #      print(result)
        cv2.imwrite(img_out_path, img)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='inference demo')
    parser.add_argument('--config_file', help='input config file',default="s2anet_dota.py")
    parser.add_argument('--model', help='pretrain model',default="./work_dir/s2anet/latest.pth")
    parser.add_argument('--img_dir', help='img dir',default="example")
    parser.add_argument('--out_dir', help='output dir',default="example_result")
    args = parser.parse_args()

    dota_colormap = [
        (54, 67, 244),
        (99, 30, 233),
        (176, 39, 156),
        (183, 58, 103),
        (181, 81, 63),
        (243, 150, 33),
        (212, 188, 0),
        (136, 150, 0),
        (80, 175, 76),
        (74, 195, 139),
        (57, 220, 205),
        (59, 235, 255),
        (0, 152, 255),
        (34, 87, 255),
        (72, 85, 121)]

    hrsc2016_colormap = [(212, 188, 0)]
    save_det_result(args.config_file, args.out_dir, checkpoint_file=args.model, img_dir=args.img_dir,
                    colormap=dota_colormap)

总结

通过此次任务对于目标检测、尤其是遥感目标检测的一般方法有了比较清晰的认识、对于mmdetecion也框架也有了比较深刻的理解。对目标检测各个常见数据都跑了一遍,数据处理了一遍,也掌握了自定义数据集的数据处理方法,未来可以很快上手,完成相似的任务。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-27 12:50:24  更:2021-10-27 12:51:09 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 8:09:46-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码