最近做语义分割遇到数据增强如何同时给原始图像和标签同时进行相同变换的问题,我之前在网上搜了下,能找到的方案虽然实现了但看起来很笨拙,所以我把我的方法分享一下。
先看结果:
左边是经过变换的道路标签,右边是经过变换的原始图像,可以看出两者都往左移动了一下(看图像黑色边缘)(其实两者还进行了随机平移、随机旋转等操作,只不过随机量一样看不出来),怎么实现的呢?
只需要在修改自定义data.Dataset类的内容即可
在它的def __ init __(self)里面加入
rot_degree = random.choice([0, 90, 180, 270]) # 随机旋转角度
sat_enhance_list = [
transforms.RandomHorizontalFlip(0.5), # 0.5的概率随机左右翻转
transforms.RandomVerticalFlip(0.5), # 0.5的概率随机上下翻转
transforms.RandomRotation((rot_degree, rot_degree)), # 随机旋转
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),# 随机水平、垂直平移-0.1~0.1,随机缩放到0.9~1.1
# transforms.ToTensor(), # 转换为张量,[0,255]->[0,1]
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化,[0,1]->[-1,1]
] # 遥感图像有而道路图像没有的数据增强
road_enhance_list = [
transforms.RandomHorizontalFlip(0.5), # 0.5的概率随机左右翻转
transforms.RandomVerticalFlip(0.5), # 0.5的概率随机上下翻转
transforms.RandomRotation((rot_degree, rot_degree)), # 随机旋转
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), # 随机平移-0.1~0.1距离,随机缩放到0.9~1.1
# transforms.ToTensor() # 转换为张量,[0,255]->[0,1]
] # 遥感图像和道路图像的数据增强
self.transform_sat = transforms.Compose(sat_enhance_list)
self.transform_road = transforms.Compose(road_enhance_list)
在它的def __ getitem __(self,index)里面加入
sat_data = Image.open(join(self.a_path, self.img_filenames[index][0])).convert('RGB')
road_data = Image.open(join(self.b_path, self.img_filenames[index][0])).convert('L')
sat_data=transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.1)(sat_data) # 随机改变亮度、对比度、饱和度、色相
seed = np.random.randint(2147483647) # 随机数种子
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
sat_data = self.transform_sat(sat_data) # 遥感图像数据增强
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
road_data = self.transform_road(road_data) # 道路图像数据增强
sat_data.show()
road_data.show()
有几点需要注意的地方:
1、我是为了方便展示,所以注释了transforms.ToTensor()、transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))并加入的sat_data.show()、road_data.show()(Image.show()里面要是Image数据类型而不能是tensor否则会报错),正常运行要注释sat_data.show()、road_data.show(),并取消注释transforms.ToTensor()、transforms.Normalize((0.5, 0.5, 0.5)
2、可以看到我在对sat_data用transform_sat()进行变换前加了transforms.ColorJitter()变换,可以把ColorJitter()加入到self.transform_sat()里面吗?不可以放在开头,但可以放在末尾,为什么呢?我之前把它放在开头,得到的结果显示原始图像和标签的变换不一样,我想了一下,可能原因是这样的:给torch了随机数种子seed后,torch在经过一个随机过程后随机数种子变成了seed_2(并且相同的seed得到的seed_2是一样的),以此类推,seed_3,seed_4…,所以要保证我的随机操作RandomHorizontalFlip、RandomVerticalFlip、.RandomRotation、RandomAffine采用的随机数一样,就得满足:(1)这些随机操作相同并且排列顺序一样(2)第一个随机操作对应的种子一样。 所以回到问题,我想对原始图像进行transforms.ColorJitter()变换而不对标签进行这样的变换,就不应该把它放在sat_enhance_list里面第一个,可以放在最后,但考虑到放在最后的话原始图像经过随机平移后产生的黑色边缘经过色彩抖动后会变成其他颜色,感觉不太好,所以是采用了上述代码所用的方法,原始图像在赋予seed前进行色彩抖动
|