IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样 -> 正文阅读

[人工智能]YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样


如有错误,恳请指出。


这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。


1. 图片采样策略想法

  • 图片采样策略想法

在我们训练数据集的时候,一般是对数据集随机采样几张图像然后构建成一个mini-batch来批量输入网络处理。个人猜想,一个可能的想法就是,这种随机的图像采集会不会过于随意,因为有些图像的目标是过少的,那么这种图像可能对网络来说比较简单;而有些图像的目标是比较多的,这种是比较困难的。而对于开始训练的初期就使用这种简答图像对网络的训练可能带来不了多大的学习提升。

所以,如果可以对数据集中的每张图像做一个权重的划分,在训练模型的时候依照图像的权重大小依次按难到易的大概顺序来进行训练,让模型从一开始的困难的样本较快的学习到潜在特征,到之后通过简单的图像样本来对参数进行微调,说不定是一个好的方法。

(以上内容是个人的思考猜测,可能是有误的,欢迎探讨。)

  • 图片采样策略思路

那么具体的实现思路就是,对整个数据集的图像目标做类别统计,然后类别的数目越大权重越小(成反比的关系)。然后再使用整个数据集的类别权重对每一张图像做类别权重的叠加。也就是根据每一张的图片的类别权重和来作为采样的权重,决定其采用的顺序。在代码的实现中是从大到小排序的。


2. 图片采样策略代码

  • yolov5参考代码

大概的注释都写在代码里了:

def train():
	...
	model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
	...
	for epoch in range(start_epoch, epochs): 
		model.train()

		# Update image weights (optional, single-GPU only)
        if opt.image_weights:
            # 根据数据集的类别数目构建每个类别的权重(类别权重与类别数目成反比)
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            # 对每张图片的目标计算其类别权重和作为图片的采集权重
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            # 再更具每张图片的采集权重来构建图片的采样顺序
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
	...


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    # labels是当前数据集的训练集的所有图像: {list: 682}
    # 列表的每个对象格式是: (ndarray: (k, 5)) k表示当前图像的目表个数, 5是(class+xywh)
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()
    # 把图像的标签列表直接转化为标签列表:{ndarray: (labels, 5)} labels表示全部图像的所有标签个数
    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    # 提取类别 labels[:, 0] 数据来为每一类做统计 .astype(np.int): 取整
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    # weight: 统计每个类别出现的次数
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    # 将出现次数为0的类别权重全部取1
    weights[weights == 0] = 1  # replace empty bins with 1
    # 类别权重取类别出现次数的倒数, 也就是表示类别次数与权重成反比, 标签频率越高的类别权重越低, 因为越不罕见
    weights = 1 / weights  # number of targets per class
    # 归一化操作: 求出每一类别的占比
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)  # numpy -> tensor


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # out:{ndarray: (682,3)} 统计每一张图片中类类别的数目 这里我用的是mask数据集有3个类别 每个位置存储图像中对应类别目标出现的个数
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
    # class_weights:[n_class] -> [1, n_class]
    # 每张图片的每个类别个数[label_nums, n_class] * 整个数据集每个类别的权重[1, n_class] = 每张图片的对应每个类别的权重[label_nums, n_class_weight]
    # 然后每个类别的权重加在一起等于当前这张图片的权重
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights
  • 构造Dataset使用的地方
class LoadImagesAndLabels(Dataset):
	def __init__(self, img_size=640, batch_size=16, image_weights=False, ...):
		...
		self.indices = range(n)
	
	def __len__(self):
        return len(self.img_files)
	
    def __getitem__(self, index):
    	# 重点使用部分, 就是用权重采样策略替代了随机采样
    	# 随机采样: index返回的是随机值(shuffle = True),所以注意到其实在
    	# 权重采样: index是按顺序从0开始, 然后依次提取indices所指向的图像索引
        index = self.indices[index]  # linear, shuffled, or image_weights
        img, labels = load_mosaic(self, index)
        ...
        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

# 因为可以注意到, 构建dataloader的时候yolov5代码中是没有使用shuffle=True这个随机采样的参数的
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))

    # 这里对num_worker进行更改
    # nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
    nw = 0  # 可以适当提高这个参数0, 2, 4, 8, 16…

    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    
    # 没有使用 shuffle=True 这个参数
	dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset

所以从代码中可以看见,如果不使用图像采样策略,这里也不会使用随机的选择策略,而且index从0开始提取,验证如下:

第一次断点调试:index从0开始,想法验证成功

在这里插入图片描述


参考资料:

1. 【YOLOV5-5.x 源码解读】general.py

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-16 21:42:28  更:2022-06-16 21:43:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 2:24:45-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码