数据增强
深度学习模型的鲁棒性(robustness)和泛化性(generalization)受训练数据的多样性和数据量影响。数据增强(data augmentation)是机器学习和深度学习中经常采用的一种方法,其目的是扩大训练样本的数量。
语义分割是计算机视觉中一个重要的下游任务,语义分割的数据增强通常需要对图像及其对应的标签做相同的增强处理。
本文总结了3种常用的增强方式:(1)旋转,(2)翻转,(3)裁剪。所有操作均采用OpenCV库进行。
首先使用opencv定义数据读取和保存函数。
import cv2
import os
import numpy as np
___________________________________________________________
def read_data(file, mode=1):
    """Load an image from disk with OpenCV.

    Args:
        file: Path of the image file to read.
        mode: 1 to read a 3-channel colour image (returned in RGB channel
            order), 0 to read a single-channel grayscale image.

    Returns:
        The decoded image array. For ``mode == 1`` the channels are
        converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If ``mode`` is neither 1 nor 0.
        OSError: If OpenCV could not read or decode ``file``.
    """
    if mode == 1:
        img = cv2.imread(file)
    elif mode == 0:
        img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
    else:
        raise ValueError("mode should be a bool number 1 or 0")

    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface later as an unrelated error (or not at all).
    if img is None:
        raise OSError(f"failed to read image: {file!r}")

    if mode == 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def save_data(save_pth, img):
    """Write an image array to disk with OpenCV.

    Args:
        save_pth: Destination file path; the extension selects the format.
        img: Image array. It is cast to ``uint8`` before writing. A
            3-channel array is treated as RGB, the channel order produced
            by ``read_data(..., mode=1)``.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    img = img.astype(np.uint8)
    # cv2.imwrite expects BGR channel order, while read_data hands out RGB;
    # convert back so a read/save round trip does not swap red and blue.
    if img.ndim == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # imwrite returns False (rather than raising) when the write fails.
    if not cv2.imwrite(save_pth, img):
        raise OSError(f"failed to write image: {save_pth!r}")
旋转操作(Rotate)
def rotate(img, gt, angle=10):
    """Rotate an image and its label map by the same angle about the centre.

    Args:
        img: Path of the image file, or an already loaded image array.
        gt: Path of the label (ground-truth) file, or an already loaded
            label array.
        angle: Rotation angle in degrees, forwarded to
            ``cv2.getRotationMatrix2D`` (positive is counter-clockwise).

    Returns:
        A ``(rotated_img, rotated_gt)`` tuple with the same height and width
        as the inputs. Areas rotated in from outside the image are filled
        with OpenCV's default border value (0).
    """
    if not isinstance(img, np.ndarray):
        img = read_data(img, 1)
    if not isinstance(gt, np.ndarray):
        gt = read_data(gt, 0)
    assert img.shape[:2] == gt.shape[:2]

    h, w = img.shape[:2]
    center = (w / 2, h / 2)
    mat = cv2.getRotationMatrix2D(center, angle, scale=1)

    # warpAffine takes the output size as (width, height), not (rows, cols).
    out_size = (w, h)
    rotated_img = cv2.warpAffine(img, mat, out_size)
    # Labels are class ids: nearest-neighbour sampling keeps them valid,
    # whereas the default bilinear interpolation would blend neighbouring ids.
    rotated_gt = cv2.warpAffine(gt, mat, out_size, flags=cv2.INTER_NEAREST)
    return rotated_img, rotated_gt
翻转操作(flip)
def flip(img, gt, direction=1):
    """Flip an image and its label map in the same direction.

    Args:
        img: Path of the image file (str), or an already loaded image array.
        gt: Path of the label (ground-truth) file (str), or an already
            loaded label array.
        direction: Flip code forwarded to ``cv2.flip``: 1 flips
            horizontally, 0 flips vertically.

    Returns:
        A ``(flipped_img, flipped_gt)`` tuple.
    """
    # Each argument is loaded on its own so a path and an array can be mixed.
    if isinstance(img, str):
        img = read_data(img, 1)
    if isinstance(gt, str):
        gt = read_data(gt, 0)
    assert img.shape[:2] == gt.shape[:2]

    flipped_img = cv2.flip(img, direction)
    flipped_gt = cv2.flip(gt, direction)
    return flipped_img, flipped_gt
裁剪操作(crop)
def crop(img, gt, size=512):
    """Cut five square patches from an image and its label map.

    The patches are the four corners and the centre; the same window is
    applied to the image and to the label map.

    Args:
        img: Path of the image file, or an already loaded image array.
        gt: Path of the label (ground-truth) file, or an already loaded
            label array.
        size: Side length of each square patch in pixels (default 512).

    Returns:
        A tuple of five ``(sub_img, sub_gt)`` pairs in the order upper-left,
        upper-right, bottom-left, bottom-right, centre.

    Raises:
        ValueError: If ``size`` is not positive or exceeds the image height
            or width.
    """
    if not isinstance(img, np.ndarray):
        img = read_data(img, 1)
    if not isinstance(gt, np.ndarray):
        gt = read_data(gt, 0)
    assert img.shape[:2] == gt.shape[:2]

    h, w = img.shape[:2]
    # A too-small input would make the centre window start at a negative
    # index, which wraps around and yields patches of inconsistent size.
    if size <= 0 or h < size or w < size:
        raise ValueError(
            f"crop size {size} does not fit an image of height {h} and width {w}"
        )

    top = h // 2 - size // 2
    left = w // 2 - size // 2
    head, tail = slice(None, size), slice(-size, None)
    windows = (
        (head, head),  # upper-left
        (head, tail),  # upper-right
        (tail, head),  # bottom-left
        (tail, tail),  # bottom-right
        (slice(top, top + size), slice(left, left + size)),  # centre
    )
    return tuple((img[rows, cols], gt[rows, cols]) for rows, cols in windows)
|