基于CNN的图像/语义分割算法主要有Unet FCN PSPnet DAnet DeepLabV3+,HRnet+OCR等,去年年底基于Transform的各类CV算法(如ViT,Swin等)在分割/分类任务上都表现了相比CNN更为优秀的分割精度。
这里就简单介绍一下基于Swin模块的Unet分割算法:来自慕尼黑工业大学的Swin-Unet
论文:https://arxiv.org/abs/2105.05537 代码:https://github.com/HuCaoFighting/Swin-Unet
首先我们看模型结构: 整个网络结构看起来非常的清楚,可以说基本上就相当于把Unet中的2D卷积换成了Swin模块。对于Swin提出的W-MSA和SW-MSA在前面Swinformer那一期介绍了一下。更详细的还是得看代码。Swin论文那里我认为为了讲故事这块结构写的的有点玄学了。
整体结构和算法部分下面我跟着代码一起详细介绍:
首先是数据增广:
Swin-Unet代码结构比较清晰清爽,整体逻辑非常清晰:
def random_rot_flip(image, label):
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label
def random_rotate(image, label):
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label
图像增广方面就用了两个,一个是图像和label同步进行随机翻转,一个是图像和label进行正负20度随机旋转
其他的就很常规了: 首先写了一个Synapse_dataset类,通过继承torch的Dataset类,复写Dataset中的__len__和__getitem__方法,其中__getitem__主要是读图像和label的numpy数组,利用上面的图像增广做同步矩阵变换之后转换成pytorch的torch.tensor后喂入模型,__getitem__主要是读到图像时同步为图像和label做对应的操作。这里我为了博客轻量把具体实现代码去掉了。想看的同学可以去看这块源码。很简单。
class Synapse_dataset(Dataset):
def __init__(self, base_dir, list_dir, split, transform=None):
def __len__(self):
return len(self.sample_list)
def __getitem__(self, idx):
return sample
最后老办法喂入torch的dataloader后通过epoch的for循环同步读取训练数据的image和label的tensor:
trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
worker_init_fn=worker_init_fn)
数据预处理说完了。接下来介绍网络实现步骤:
首先是transform的PatchEmbed结构:
整个结构基本上就是照搬Swin的PatchEmbed方法,直接通过一个2D卷积 表征位置信息(事实上目前很多基于Transform的算法都这么干的)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
熟悉Unet结构的同学应该清楚整个Unet核心其实就三部分:
编码头: 对图像特征进行聚合,同时下采样,WH减半,channel同步增加(由于Swin输入多少输出多少,所以下采样功能是通过torch的linear层实现的)
解码头: 将图像上采样要原图大小方便进行像素点分类
跳连接: 网络层越深得到的特征图,有着更大的感受野,浅层卷积关注纹理特征,深层网络关注本质的那种特征,通过跳连接可以使特征向量同时具有深层和表层特征(cat方法),由于图像在上采样过程(CNN的图像分割一般通过2Dconv+双线性插值进行上采样)本身不增加新的信息,但是每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征的一个找回。
由于SwinBlock相比CNN比较特殊,它的输入和输出是一样的,下采样主要
上采样: 作者尝试了双线性插值/转置卷积/Patch expand三种方法,通过对比实验证明了其有效性: Patch expand方法其实很简单,首先通过一个线性层把长采样到两倍,然后通过torch.view()通道数变成1/4,wh各增加2倍。cat后刚好和encode对齐
class PatchExpand(nn.Module):
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
x = x.view(B,-1,C//4)
x= self.norm(x)
return x
跳连接个数: 如下表显示跳连接确实是work的
损失函数部分:
Swin-Unet的损失函数有任何的改进,是0.4的交叉熵+0.6的dice-loss构成
outputs = model(image_batch)
loss_ce = ce_loss(outputs, label_batch[:].long())
loss_dice = dice_loss(outputs, label_batch, softmax=True)
loss = 0.4 * loss_ce + 0.6 * loss_dice
效果:
Swin-Unet凭借Swin中MSA强大特征提取能力。相比一众算法展现了sota的效果:
总结:Swin-Unet只是在各个特征提取模块将Unet的2D卷积换成了Swin结构,在Swin结构和Unet结构上基本没有改变,损失函数也没有做变化。再次说明了Swin模块的强大特征提取能力(感觉创新不太够啊,不过代码挺清爽的)
|