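Data augmentation is an important step in training deep learning models. For 2D data it is already fairly mature, and the official framework libraries provide their own augmentation functions. Code for 3D data augmentation, however, is much harder to find, so here I go through the augmentation methods I have used on 3D medical data.
3D medical data
In medical image processing, two common image formats are nii (NIfTI) and DICOM. In my project, I first converted the DICOM data to nii. Files in nii format can be viewed with ITK-SNAP. [Figure: a nii volume opened in ITK-SNAP] In the code below the data is read from nii files, but the format makes no difference to the augmentation that follows, because every format is first read into an array and then processed.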
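The DICOM-to-nii conversion is not shown in this post. As a rough sketch of one possible way to do it, the snippet below uses SimpleITK; that choice is an assumption made for illustration, and the function name `dicom_series_to_nifti` and its two path arguments are hypothetical.

```python
def dicom_series_to_nifti(dicom_dir, output_path):
    """Read one DICOM series from a folder and write it as a NIfTI file."""
    # Imported here so that SimpleITK is only required when the function is called.
    import SimpleITK as sitk

    reader = sitk.ImageSeriesReader()
    # Collect the slice files of the series found in the folder.
    file_names = reader.GetGDCMSeriesFileNames(dicom_dir)
    if not file_names:
        raise ValueError("no DICOM series found in %s" % dicom_dir)
    reader.SetFileNames(file_names)
    image = reader.Execute()
    # The file extension (.nii or .nii.gz) selects the output format.
    sitk.WriteImage(image, output_path)
    return image.GetSize()
```

A call would look like `dicom_series_to_nifti("path/to/dicom_folder", "volume.nii.gz")`, with both paths replaced by real ones.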
Data augmentation libraries
Here are two commonly used augmentation libraries for 3D medical data:
- Volumentations 3D
from volumentations import *
A Python-based 3D augmentation library, so it can be used with both TensorFlow and PyTorch.
- TorchIO
A PyTorch-based 3D augmentation library that, besides augmentation, also includes many medical image processing tools.
import torchio as tio
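I will mainly cover the first library here; some of the functions are ones I wrote myself.
Data visualization
To inspect a 3D medical image, we usually open the nii file in ITK-SNAP. So after augmenting the original image, we can save the result as a nii file and view it in ITK-SNAP. The code for converting an array to nii is:
import numpy as np
import nibabel as nib
# my_arr is the volume to save, as a NumPy array; np.eye(4) is an identity affine.
new_image = nib.Nifti1Image(my_arr, np.eye(4))
nib.save(new_image, 'nifti.nii.gz')
However, saving a new nii file and opening it after every augmentation is tedious, so I wrote a visualization function, draw_oct, which plots the middle slice along each dimension, just as ITK-SNAP does; the slice is taken at the center of each dimension. volume is the data to visualize. type_volume is the data type: np for data handled with Volumentations 3D, tensor for data handled with TorchIO. The last argument, canal_first, gives the position of the channel dimension. Although we work with 3D medical data, networks usually expect a 4D input: besides height, depth and width there is a channel dimension, usually 1 or 3 for grayscale or color images. Its position differs between TensorFlow and PyTorch, sometimes first and sometimes last, so pay attention to it when calling the function.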
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def draw_oct(volume, type_volume='np', canal_first=False):
    """Plot the middle slice of a 4D volume along each of its three spatial axes."""
    print("volume size = %s (%s)" % (volume.shape, type_volume))
    if type_volume == 'tensor':
        # Convert a torch tensor to a NumPy array before slicing.
        volume = volume.numpy()
    # The spatial axes are 1..3 when the channel comes first, 0..2 otherwise.
    offset = 1 if canal_first else 0
    plt.figure(figsize=(21, 7))
    for i in range(3):
        axis = i + offset
        # Take the central slice along this axis, then drop the channel axis.
        middle = int(volume.shape[axis] / 2)
        slice_2d = Image.fromarray(np.squeeze(np.take(volume, middle, axis=axis)))
        plt.subplot(1, 3, i + 1)
        plt.imshow(slice_2d)
        plt.title(slice_2d.size)
        plt.axis('off')
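Since the channel position is easy to get wrong, here is a small NumPy sketch of adding a channel axis to a 3D volume and moving it between the last and the first position. The array shape is made up for illustration.

```python
import numpy as np

# A made-up 3D volume (height, depth, width) without a channel axis.
volume_3d = np.zeros((30, 38, 30), dtype=np.float32)

# Add a trailing channel axis: (30, 38, 30) -> (30, 38, 30, 1).
channel_last = np.expand_dims(volume_3d, axis=-1)

# Move the channel axis to the front: (30, 38, 30, 1) -> (1, 30, 38, 30).
channel_first = np.moveaxis(channel_last, -1, 0)

# Move it back to the end to recover the channel-last layout.
restored = np.moveaxis(channel_first, 0, -1)

print(channel_last.shape, channel_first.shape, restored.shape)
```

A channel-last array goes to draw_oct with canal_first = False, a channel-first one with canal_first = True.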
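Original image
Let's plot the original image so that it can be compared with the augmented ones:
import numpy as np
from scipy import ndimage
import nibabel as nib
image_structure_name = "structure_volume_25.nii.gz"
volume_structure_ori = nib.load(image_structure_name)
# get_data() has been removed from recent nibabel releases; get_fdata() returns a floating-point array.
volume_structure_ori = volume_structure_ori.get_fdata().astype(np.float32)
draw_oct(volume_structure_ori)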
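Augmentation operations
I will go through each augmentation operation. In Volumentations 3D, the p argument of each transform is the probability that the operation is applied. To chain several augmentations, use Compose.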
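To make the role of p concrete, here is a minimal sketch of probability-gated transforms chained one after another. It is not the library's code: `with_probability` and `compose` are hypothetical helpers written only to illustrate the idea, and the tiny volume is made up.

```python
import random
import numpy as np

def with_probability(transform, p):
    """Wrap a transform so that it only runs with probability p."""
    def wrapped(volume, rng):
        return transform(volume) if rng.random() < p else volume
    return wrapped

def compose(transforms):
    """Chain wrapped transforms; each one receives the output of the previous one."""
    def pipeline(volume, rng):
        for t in transforms:
            volume = t(volume, rng)
        return volume
    return pipeline

# Two flips, each applied independently with probability 0.5.
flip_h = with_probability(lambda v: np.flip(v, axis=0), p=0.5)
flip_w = with_probability(lambda v: np.flip(v, axis=2), p=0.5)
augment = compose([flip_h, flip_w])

rng = random.Random(0)
volume = np.arange(8, dtype=np.float32).reshape(2, 2, 2, 1)
augmented = augment(volume, rng)
print(augmented.shape)
```

Because every transform draws its own random number, a pipeline of two transforms with p = 0.5 leaves the volume untouched about a quarter of the time.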
Resize
def resize_image(image, image_size):
    """
    Resizes an image to the neural network input size.
    :param image: input image, a 4D array with the channel axis last
    :param image_size: target size, one value per axis of the image
    :return: the resized image
    """
    # The small epsilon guards against a zoomed axis being rounded below the target size.
    epsilon = 1e-6
    image = ndimage.zoom(image, zoom=[(float(image_size[i]) / image.shape[i] + epsilon)
                                      for i in range(4)], order=1)
    # Center-crop whatever the rounding left over on the three spatial axes.
    trim = [(image.shape[i] - image_size[i]) // 2 for i in range(3)]
    return image[trim[0]:trim[0] + image_size[0], trim[1]:trim[1] + image_size[1],
                 trim[2]:trim[2] + image_size[2], :]
volume_structure_224 = resize_image(volume_structure_ori, (224, 224, 224, 1))
draw_oct(volume_structure_224)
RandomCrop
patch_size = (224,224,224)
aug_randomcrop = Compose([
    RandomCrop(patch_size)
], p=1.0)
data_structure = {'image': volume_structure_ori}
aug_data = aug_randomcrop(**data_structure)
volume_randomcrop = aug_data['image']
draw_oct(volume_randomcrop)
Pad
By default this pads with zeros in constant mode. To change the padding parameters, see the arguments of the original function in volumentations.
import time
aug_pad = Compose([
    PadIfNeeded(shape=(300, 300, 300), border_mode="constant")
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_224}
aug_data = aug_pad(**data_structure)
volume_pad = aug_data['image']
end = time.time()
print('time for transform =', end - start)
draw_oct(volume_pad)
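For readers who want to see what constant zero padding amounts to, here is a plain NumPy sketch. `pad_to_shape` is a hypothetical helper, not the library's implementation, and centering the original volume inside the padded one is a choice made for this sketch only.

```python
import numpy as np

def pad_to_shape(volume, target_shape, value=0.0):
    """Pad the three spatial axes of a channel-last 4D volume up to target_shape."""
    pad_width = []
    for size, target in zip(volume.shape[:3], target_shape):
        # An axis that is already large enough gets no padding.
        missing = max(target - size, 0)
        before = missing // 2
        pad_width.append((before, missing - before))
    # Leave the channel axis untouched.
    pad_width.append((0, 0))
    return np.pad(volume, pad_width, mode="constant", constant_values=value)

small = np.ones((4, 6, 4, 1), dtype=np.float32)
padded = pad_to_shape(small, (7, 6, 5))
print(small.shape, padded.shape)
```

The padded voxels are all zero, so the sum of the volume is unchanged by the operation.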
Normalize
aug_normalize = Compose([
    Normalize()
], p=1.0)
data_structure = {'image': volume_structure_ori}
aug_data = aug_normalize(**data_structure)
volume_normalize = aug_data['image']
draw_oct(volume_normalize*255)
Crop
Unlike RandomCrop, crop here is code I wrote myself that tiles the whole volume. Because cropping along the depth would affect the recognition of pathology in the image, I only tile along the width and the height. With volume_size = (300, 384, 300, 1) and crop_size = (224, 384, 224, 1), this should produce four sub-volumes.
import math
print('volume_size=', volume_structure_ori.shape)
crop_size = (224, 384, 224, 1)
print('crop_size=', crop_size)
h = volume_structure_ori.shape[0]
w = volume_structure_ori.shape[2]
crop_h = crop_size[0]
crop_w = crop_size[2]
nombre_h = math.ceil(h / crop_h)
nombre_w = math.ceil(w / crop_w)
nombre_volume = 0
for i_h in range(nombre_h):
    for i_w in range(nombre_w):
        start_h = i_h * crop_h
        start_w = i_w * crop_w
        # The last crop on each axis is shifted back so that it ends at the volume border.
        if i_h == nombre_h - 1:
            start_h = h - crop_h
        if i_w == nombre_w - 1:
            start_w = w - crop_w
        sub_volume = volume_structure_ori[start_h:start_h + crop_h, :, start_w:start_w + crop_w, :]
        draw_oct(sub_volume)
        nombre_volume = nombre_volume + 1
print('number of sub-volumes =', nombre_volume)
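The arithmetic behind the four sub-volumes can be checked on a single axis. The helper below is a hypothetical restatement of the loop above for one axis, assuming the axis is at least as long as the crop.

```python
import math

def crop_starts(length, crop):
    """Start offsets of the crops that tile one axis of the given length."""
    count = math.ceil(length / crop)
    starts = [i * crop for i in range(count)]
    # Shift the last crop back so that it ends exactly at the border.
    starts[-1] = length - crop
    return starts

starts_h = crop_starts(300, 224)
starts_w = crop_starts(300, 224)
print(starts_h, starts_w, len(starts_h) * len(starts_w))
# prints: [0, 76] [0, 76] 4
```

With a length of 300 and a crop of 224, the two crops on each axis start at 0 and 76, so they overlap on indices 76 to 223, and two crops per axis on two axes give the four sub-volumes.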
Flip
aug_flip = Compose([
    Flip(0, p=0.5),
    Flip(2, p=0.5)
], p=1.0)
data_structure = {'image': volume_structure_ori}
aug_data = aug_flip(**data_structure)
volume_flip = aug_data['image']
draw_oct(volume_flip)
Rotate
aug_rotate = Compose([
    Rotate((-15, 15), (0, 0), (0, 0), p=0.5),
], p=1.0)
data_structure = {'image': volume_structure_ori}
aug_data = aug_rotate(**data_structure)
volume_rotate = aug_data['image']
draw_oct(volume_rotate)
ElasticTransform
This operation is fairly time-consuming and may make training too slow, so use it with discretion.
aug_Elastic = Compose([
    ElasticTransform((0, 0.25), interpolation=2, p=0.1),
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_ori}
aug_data = aug_Elastic(**data_structure)
volume_Elastic = aug_data['image']
end = time.time()
print('time for transform =',end-start)
draw_oct(volume_Elastic)
RandomRotate90
aug_RandomRotate90 = Compose([
    RandomRotate90((1, 2), p=0.5)
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_ori}
aug_data = aug_RandomRotate90(**data_structure)
volume_RandomRotate90 = aug_data['image']
end = time.time()
print('time for transform =',end-start)
draw_oct(volume_RandomRotate90)
GaussianNoise
aug_GaussianNoise = Compose([
    GaussianNoise(var_limit=(0, 5), p=0.5)
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_ori}
aug_data = aug_GaussianNoise(**data_structure)
volume_GaussianNoise = aug_data['image']
end = time.time()
print('time for transform =',end-start)
draw_oct(volume_GaussianNoise)
RandomGamma
aug_RandomGamma = Compose([
    RandomGamma(gamma_limit=(0.5, 1.5), p=0.5)
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_ori}
aug_data = aug_RandomGamma(**data_structure)
volume_RandomGamma = aug_data['image']
end = time.time()
print('time for transform =',end-start)
draw_oct(volume_RandomGamma)
GridDropout
aug_GridDropout = Compose([
    GridDropout(ratio=0.5, unit_size_min=50, unit_size_max=60,
                holes_number_x=3, holes_number_y=2, holes_number_z=2, p=0.5)
], p=1.0)
start = time.time()
data_structure = {'image': volume_structure_ori}
aug_data = aug_GridDropout(**data_structure)
volume_GridDropout = aug_data['image']
end = time.time()
print('time for transform =',end-start)
draw_oct(volume_GridDropout)
CutoutAbs
Randomly covers a block of the original volume with black.
import random
def CutoutAbs(volume, ratio=0.5):
    length_w = int(ratio * volume.shape[0])
    length_d = int(ratio * volume.shape[1])
    length_h = int(ratio * volume.shape[2])
    # random.randint includes both bounds, so the largest valid start is shape - 1.
    start_w = random.randint(0, volume.shape[0] - 1)
    start_d = random.randint(0, volume.shape[1] - 1)
    start_h = random.randint(0, volume.shape[2] - 1)
    # Clip the block at the border of the volume.
    end_w = min(start_w + length_w, volume.shape[0])
    end_d = min(start_d + length_d, volume.shape[1])
    end_h = min(start_h + length_h, volume.shape[2])
    # Work on a copy so that the input volume is left unchanged.
    new_volume = volume.copy()
    new_volume[start_w:end_w, start_d:end_d, start_h:end_h, :] = 0
    return new_volume
volume_structure_CutoutAbs = CutoutAbs(volume_structure_ori, 0.5)
draw_oct(volume_structure_CutoutAbs)
Random blur (torchio)
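Finally, a brief note on how to use torchio for augmentation. It accepts several kinds of input, either array or tensor, but all of them must be channel-first, so the dimensions have to be rearranged.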
volume_structure_canalfirst = volume_structure_ori.transpose(3, 0, 1, 2)
blur = tio.RandomBlur()
blurred = blur(volume_structure_canalfirst)
draw_oct(blurred,canal_first = True)
Random augmentation
Finally, in the FixMatch method, strong augmentation consists of first picking n augmentation operations at random from an augmentation pool and then applying CutoutAbs. Let's see how to implement this on a 3D volume.
def my_augment_pool():
    augs = [
        Flip(0, p=0.5),
        Flip(2, p=0.5),
        Rotate((-15, 15), (0, 0), (0, 0), p=0.5),
        ElasticTransform((0, 0.25), interpolation=2, p=0.5),
        RandomRotate90((1, 2), p=0.5),
        GaussianNoise(var_limit=(0, 5), p=0.5),
        RandomGamma(gamma_limit=(0.5, 1.5), p=0.5),
        GridDropout(ratio=0.5, unit_size_min=50, unit_size_max=60,
                    holes_number_x=3, holes_number_y=2, holes_number_z=2, p=0.5)
    ]
    return augs
augs = my_augment_pool()
def RandAugmentMC(volume, augs, n=2, crop=False, patch_size=(224, 224, 224)):
    # random.choices draws with replacement, so the same operation may be picked twice.
    ops = random.choices(augs, k=n)
    aug_list = list(ops)
    if crop:
        aug_list.append(RandomCrop(patch_size))
    aug_strongly = Compose(aug_list, p=1.0)
    # Work on a copy so that the input volume is left unchanged.
    new_volume = volume.copy()
    start = time.time()
    data_structure = {'image': new_volume}
    aug_data = aug_strongly(**data_structure)
    new_volume = aug_data['image']
    end = time.time()
    print('time for transform =', end - start)
    # Strong augmentation always ends with a cutout.
    new_volume = CutoutAbs(new_volume, 0.5)
    return new_volume
volume_stronglyaug = RandAugmentMC(volume_structure_ori, augs, n=2, crop=True)
draw_oct(volume_stronglyaug)
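One detail of the selection step is worth a small illustration: random.choices draws with replacement, whereas random.sample does not. The sketch below uses a made-up pool of operation names standing in for the real transforms; which behavior is preferable is a design choice.

```python
import random

# Made-up operation names standing in for the transforms in the pool.
pool = ["flip_h", "flip_w", "rotate", "elastic",
        "rotate90", "noise", "gamma", "grid_dropout"]

rng = random.Random(0)

# With replacement: the same operation can appear more than once.
with_replacement = rng.choices(pool, k=2)

# Without replacement: the selected operations are always distinct.
without_replacement = rng.sample(pool, k=2)

print(with_replacement, without_replacement)
```

With a pool of eight operations and n = 2, drawing with replacement picks the same operation twice in one case out of eight on average.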