Source code: datasets.py
- Get the image path;
def __getitem__(self, index):
    img_path = self.img_files[index % len(self.img_files)].rstrip()
- Read the image, convert it to RGB, and turn it into a tensor;
    # ToTensor yields a C x H x W float tensor with values scaled to [0, 1]
    img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
- Preprocess: if the image does not have three channels, expand it to three;
    if len(img.shape) != 3:
        # Grayscale image: add a channel dimension, then repeat it across 3 channels
        img = img.unsqueeze(0)
        img = img.expand(3, *img.shape[1:])
- If the image is not square, pad it: the rectangle is turned into a square and the missing area is filled in;
    _, h, w = img.shape
    # If the labels are normalized to [0, 1], they must be scaled back by the original image size
    h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
    # Pad to square resolution
    img, pad = pad_to_square(img, 0)
    _, padded_h, padded_w = img.shape
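The helper pad_to_square is defined elsewhere in the repo rather than in this method; its exact implementation is not shown here, but a minimal sketch of the idea (names and details assumed, not taken from this file) is to split the height/width difference almost evenly between the two sides with F.pad and return both the padded image and the pad amounts so the labels can be corrected afterwards:

import torch.nn.functional as F

def pad_to_square(img, pad_value):
    # img is a C x H x W tensor
    c, h, w = img.shape
    dim_diff = abs(h - w)
    # Split the difference between the two sides (nearly evenly)
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad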
- Read the label file;
    # ---------
    # Label
    # ---------
    label_path = self.label_files[index % len(self.img_files)].rstrip()
    label_path = 'E:\\eclipse-workspace\\PyTorch\\PyTorch-YOLOv3\\data\\coco\\labels' + label_path
    #print (label_path)
    targets = None
    if os.path.exists(label_path):
        boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
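        # Each row of a YOLO-format label file is: class x_center y_center width height,
        # with the box coordinates normalized to [0, 1]. An illustrative (made-up) row:
        #   16 0.606 0.341 0.544 0.510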
- Recover the pixel coordinates of the box corners on the unpadded image;
        # Extract coordinates for unpadded + unscaled image
        x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
        y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
        x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
        y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
- Because the image was padded from a rectangle into a square, the label coordinates also have to be shifted by the same padding amounts returned by pad_to_square;
        # Adjust for added padding
        x1 += pad[0]
        y1 += pad[2]
        x2 += pad[1]
        y2 += pad[3]
- In the YOLOv3 paper, the values to be predicted are not x1, y1, x2, y2 but a center point plus w and h, so the boxes are converted back and re-normalized by the padded size; a worked example follows this block;
        # Returns (x, y, w, h)
        boxes[:, 1] = ((x1 + x2) / 2) / padded_w
        boxes[:, 2] = ((y1 + y2) / 2) / padded_h
        boxes[:, 3] *= w_factor / padded_w
        boxes[:, 4] *= h_factor / padded_h
        # Extra first column is left at zero; it is reserved for the sample index within a batch
        targets = torch.zeros((len(boxes), 6))
        targets[:, 1:] = boxes
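A quick worked example of the conversion (illustrative numbers, not from the dataset): take a 640x480 image with a normalized box (x, y, w, h) = (0.5, 0.5, 0.25, 0.5). The corners come out as x1 = 240, x2 = 400, y1 = 120, y2 = 360; pad_to_square adds 80 pixels above and below, so y1 and y2 become 200 and 440 while x1 and x2 are unchanged, and padded_w = padded_h = 640. The re-normalized box is then x = 0.5, y = 0.5, w = 0.25, h = 0.5 * 480 / 640 = 0.375: the center stays put, and the height shrinks because the square canvas is taller than the original image.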
- Optionally apply data augmentation;
    # Apply augmentations
    if self.augment:
        if np.random.random() < 0.5:
            img, targets = horisontal_flip(img, targets)
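The helper horisontal_flip comes from the repo's augmentation utilities and is not shown in this file; a minimal sketch of the idea, assuming targets keeps the normalized center x in column 2, is to mirror the tensor along the width axis and mirror the x coordinate around the image center:

import torch

def horisontal_flip(images, targets):
    # Flip the image left-right (the last dimension is width)
    images = torch.flip(images, [-1])
    # Mirror the normalized center-x coordinate
    targets[:, 2] = 1 - targets[:, 2]
    return images, targets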
- Return the image path, the image tensor, and the padding-adjusted targets;
    return img_path, img, targets
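How the reserved first column of targets gets used: when the dataset is wrapped in a DataLoader, a custom collate_fn typically writes the sample index into column 0 so boxes from different images in the same batch can be told apart. A minimal sketch, assuming the images have already been resized to a common img_size (the full repo code does this after padding) and that the dataset object is called dataset:

from torch.utils.data import DataLoader
import torch

def collate_fn(batch):
    paths, imgs, targets = list(zip(*batch))
    # Drop samples whose label file was missing (their targets are None)
    targets = [t for t in targets if t is not None]
    # Write the sample index into column 0 so each box can be matched to its image
    for i, t in enumerate(targets):
        t[:, 0] = i
    targets = torch.cat(targets, 0)
    imgs = torch.stack(imgs)
    return paths, imgs, targets

loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)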