如有错误,恳请指出。
这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。
1. 图片采样策略想法
在我们训练数据集的时候,一般是对数据集随机采样几张图像然后构建成一个mini-batch 来批量输入网络处理。个人猜想,一个可能的想法就是,这种随机的图像采集会不会过于随意,因为有些图像的目标是过少的,那么这种图像可能对网络来说比较简单;而有些图像的目标是比较多的,这种是比较困难的。而对于开始训练的初期就使用这种简答图像对网络的训练可能带来不了多大的学习提升。
所以,如果可以对数据集中的每张图像做一个权重的划分,在训练模型的时候依照图像的权重大小依次按难到易的大概顺序来进行训练,让模型从一开始的困难的样本较快的学习到潜在特征,到之后通过简单的图像样本来对参数进行微调,说不定是一个好的方法。
(以上内容是个人的思考猜测,可能是有误的,欢迎探讨。)
那么具体的实现思路就是,对整个数据集的图像目标做类别统计,然后类别的数目越大权重越小(成反比的关系)。然后再使用整个数据集的类别权重对每一张图像做类别权重的叠加。也就是根据每一张的图片的类别权重和来作为采样的权重,决定其采用的顺序。在代码的实现中是从大到小排序的。
2. 图片采样策略代码
大概的注释都写在代码里了:
def train():
...
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
...
for epoch in range(start_epoch, epochs):
model.train()
if opt.image_weights:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)
...
def labels_to_class_weights(labels, nc=80):
if labels[0] is None:
return torch.Tensor()
labels = np.concatenate(labels, 0)
classes = labels[:, 0].astype(np.int)
weights = np.bincount(classes, minlength=nc)
weights[weights == 0] = 1
weights = 1 / weights
weights /= weights.sum()
return torch.from_numpy(weights)
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
return image_weights
class LoadImagesAndLabels(Dataset):
def __init__(self, img_size=640, batch_size=16, image_weights=False, ...):
...
self.indices = range(n)
def __len__(self):
return len(self.img_files)
def __getitem__(self, index):
index = self.indices[index]
img, labels = load_mosaic(self, index)
...
return torch.from_numpy(img), labels_out, self.img_files[index], shapes
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment,
hyp=hyp,
rect=rect,
cache_images=cache,
single_cls=single_cls,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
batch_size = min(batch_size, len(dataset))
nw = 0
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
dataloader = loader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
return dataloader, dataset
所以从代码中可以看见,如果不使用图像采样策略,这里也不会使用随机的选择策略,而且index从0开始提取,验证如下:
第一次断点调试:index从0开始,想法验证成功
参考资料:
1. 【YOLOV5-5.x 源码解读】general.py
|