【图像分割】2021-SegFormer NeurIPS
论文题目: SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
论文地址:https://arxiv.org/abs/2105.15203v3
代码地址: https://github.com/NVlabs/SegFormer
论文团队:香港大学, 南京大学, NVIDIA, Caltech
SegFormer论文详解,2021CVPR收录,将Transformer与语义分割相结合的作品,
1. 简介
1.1 简介
- 2021可以说是分割算法爆发的一年,首先
ViT 通过引入transform将ADE20K mIOU精度第一次刷到50%,超过了之前HRnet+OCR 效果, - 然后再是
Swin 屠榜各大视觉任务,在分类,语义分割和实例分割都做到了SOTA,斩获ICCV2021的bset paper, - 然后Segformer有凭借对transform再次深层次优化,在拿到更高精度的基础之上还大大提升了模型的实时性。
动机来源有:SETR中使用VIT作为backbone提取的特征较为单一,PE限制预测的多样性,传统CNN的Decoder来恢复特征过程较为复杂。主要提出多层次的Transformer-Encoder和MLP-Decoder,性能达到SOTA。
1.2 解决的问题
SegFormer是一个将transformer与轻量级多层感知器(MLP)解码器统一起来的语义分割框架。SegFormer的优势在于:
- SegFormer设计了一个新颖的分级结构transformer编码器,输出多尺度特征。它不需要位置编码,从而避免了位置编码的插值(当测试分辨率与训练分辨率不同时,会导致性能下降)。
- SegFormer避免了复杂的解码器。提出的MLP解码器从不同的层聚合信息,从而结合局部关注和全局关注来呈现强大的表示。作者展示了这种简单和轻量级的设计是有效分割transformer的关键。
2. 网络
2.1 架构
1) 总体结构
这种架构类似于ResNet,Swin-Transformer。经过一个阶段,
-
编码器:一个分层的Transformer编码器,用于生成高分辨率的粗特征和低分辨率的细特征 由Transformer blocks*N 组成一个单独的阶段(stage)。 一个Transformer block 由3个部分组成
- Overlap Patch Merging
- Mix-FFN
- Effcient Self-Atten
-
解码器:一个轻量级的All-MLP解码器,融合这些多级特征,产生最终的语义分割掩码。
2) 编码器配置
下面是SegFormer的编码器的具体配置
3) 分层结构
与只能生成单分辨率特征图的ViT不同,该模块的目标是对给定输入图像生成类似cnn的多级特征 。这些特征提供了高分辨率的粗特征和低分辨率的细粒度特征,通常可以提高语义分割的性能。
更准确地说,给定一个分辨率为
H
×
W
×
3
H\times W\times 3
H×W×3。我们进行patch合并,得到一个分辨率为
(
H
2
i
+
1
×
W
2
i
+
1
×
C
)
(\frac{H}{2^{i+1}}\times \frac{W}{2^{i+1}}\times C)
(2i+1H?×2i+1W?×C)的层次特征图
F
i
F_i
Fi?,其中
i
∈
{
1
,
2
,
3
,
4
}
i\in\{1,2,3,4\}
i∈{1,2,3,4}。
举个例子,经过一个阶段
F
1
=
(
H
4
×
W
4
×
C
)
→
F
2
=
(
H
8
×
W
8
×
C
)
F_1=(\frac{H}{4}\times \frac{W}{4}\times C) \to F_2=(\frac{H}{8}\times \frac{W}{8}\times C)
F1?=(4H?×4W?×C)→F2?=(8H?×8W?×C)
2.2 分层的Transformer解码器
编码器由3个部分组成,首先讲一下,下采样模块
1) Overlap Patch Merging
对于一个映像patch,ViT中使用的patch合并过程将一个
N
×
N
×
3
N\times N\times 3
N×N×3的图像统一成
1
×
1
×
C
1\times 1\times C
1×1×C向量。这可以很容易地扩展到将一个
2
×
2
×
C
i
2\times 2\times C_i
2×2×Ci?特征路径统一到一个
1
×
1
×
C
i
+
1
1\times 1\times C_{i+1}
1×1×Ci+1?向量中,以获得分层特征映射。
使用此方法,可以将层次结构特性从
F
1
=
(
H
4
×
W
4
×
C
)
→
F
2
=
(
H
8
×
W
8
×
C
)
F_1=(\frac{H}{4}\times \frac{W}{4}\times C) \to F_2=(\frac{H}{8}\times \frac{W}{8}\times C)
F1?=(4H?×4W?×C)→F2?=(8H?×8W?×C)。然后迭代层次结构中的任何其他特性映射。这个过程最初的设计是为了结合不重叠的图像或特征块。因此,它不能保持这些斑块周围的局部连续性。相反,我们使用重叠补丁合并过程。因此,论文作者分别通过设置K,S,P为(7,4,3)(3,2,1)的卷积来进行重叠的Patch merging。其中,K为kernel,S为Stride,P为padding。
说的这么花里胡哨的,其实作用就是和MaxPooling一样 ,起到下采样 的效果。使得特征图变成原来的
1
2
\frac{1}{2}
21?
2) Efficient Self-Attention
编码器的主要计算瓶颈是自注意层。在原来的多头自注意过程中,每个头
K
,
Q
,
V
K,Q,V
K,Q,V都有相同的维数
N
×
C
N\times C
N×C,其中
N
=
H
×
W
N=H\times W
N=H×W为序列的长度,估计自注意为:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
d
h
e
a
d
)
V
Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_{head}}})V
Attention(Q,K,V)=Softmax(dhead?
?QKT?)V 这个过程的计算复杂度是
O
(
N
2
)
O(N^2)
O(N2),这对于大分辨率的图像来说是巨大的。
论文作者认为,网络的计算量主要体现在自注意力机制层上。为了降低网路整体的计算复杂度,论文作者在自注意力机制的基础上,添加的缩放因子
R
R
R,来降低每一个自注意力机制模块的计算复杂度。
K
^
=
R
e
s
h
a
p
e
(
N
R
,
C
?
R
)
(
K
)
K
=
L
i
n
e
a
r
(
C
?
R
,
C
)
(
K
^
)
\begin{aligned} \hat{K}&=Reshape(\frac{N}{R},C\cdot R)(K) \\ K&=Linear(C\cdot R,C)(\hat{K}) \end{aligned}
K^K?=Reshape(RN?,C?R)(K)=Linear(C?R,C)(K^)? 其中第一步将
K
K
K的形状由
N
×
C
N\times C
N×C转变为
N
R
×
(
C
?
R
)
\frac{N}{R}\times(C\cdot R)
RN?×(C?R),
第二步又将
K
K
K的形状由
N
R
×
(
C
?
R
)
\frac{N}{R}\times(C\cdot R)
RN?×(C?R)转变为
N
R
×
C
\frac{N}{R}\times C
RN?×C。因此,计算复杂度就由
O
(
N
2
)
O(N^2)
O(N2)降至
O
(
N
2
R
)
O(\frac{N^2}{R})
O(RN2?)。在作者给出的参数中,阶段1到阶段4的
R
R
R分别为
[
64
,
16
,
4
,
1
]
[64,16,4,1]
[64,16,4,1]
3) Mix-FFN
VIT使用位置编码PE(Position Encoder)来插入位置信息,但是插入的PE的分辨率是固定的,这就导致如果训练图像和测试图像分辨率不同的话,需要对PE进行插值操作,这会导致精度下降。
为了解决这个问题CPVT(Conditional positional encodings for vision transformers. arXiv, 2021)使用了3X3的卷积和PE一起实现了data-driver PE。
引入了一个 Mix-FFN,考虑了padding对位置信息的影响,直接在 FFN (feed-forward network)中使用 一个3x3 的卷积,MiX-FFN可以表示如下:
X
o
u
t
=
M
L
P
(
G
E
L
U
(
C
o
n
v
3
×
3
(
M
L
P
(
X
i
n
)
)
)
)
+
X
i
n
X_{out}=MLP(GELU(Conv_{3\times3}(MLP(X_{in}))))+X_{in}
Xout?=MLP(GELU(Conv3×3?(MLP(Xin?))))+Xin? 其中
X
i
n
X_{in}
Xin?是从self-attention中输出的feature。Mix-FFN混合了一个
3
?
3
3*3
3?3的卷积和MLP在每一个FFN中。即根据上式可以知道MiX-FFN的顺序为:输入经过MLP,再使用
C
o
n
v
3
×
3
Conv_{3\times3}
Conv3×3?操作,正在经过一个GELU激活函数,再通过MLP操作,最后将输出和原始输入值进行叠加操作,作为MiX-FFN的总输出。
在实验中作者展示了
3
?
3
3*3
3?3的卷积可以为transformer提供PE。作者还是用了深度可以分离卷积提高效率,减少参数。
2.3 轻量级MLP解码器
SegFormer集成了一个轻量级解码器,只包含MLP层。实现这种简单解码器的关键是,SegFormer的分级Transformer编码器比传统CNN编码器具有更大的有效接受域(ERF)。
SegFormer所提出的全mlp译码器由四个主要步骤组成。
- 来自MiT编码器的多级特性通过MLP层来统一通道维度。
- 特征被上采样到1/4并连接在一起。
- 采用MLP层融合级联特征
F
F
F
- 另一个MLP层采用融合的
H
4
×
W
4
×
N
c
l
s
\frac{H}{4}\times \frac{W}{4}\times N_{cls}
4H?×4W?×Ncls?分辨率特征来预测分割掩码
M
M
M,其中表示类别数目
解码器可以表述为:
F
^
i
=
L
i
n
e
a
r
(
C
i
,
C
)
(
F
i
)
,
?
i
F
^
i
=
U
p
s
a
m
p
l
e
(
W
4
×
W
4
)
(
F
^
i
)
,
?
i
F
=
L
i
n
e
a
r
(
4
C
,
C
)
(
C
o
n
c
a
t
(
F
^
i
)
)
,
?
i
M
=
L
i
n
e
a
r
(
C
,
N
c
l
s
)
(
F
)
\begin{aligned} \hat{F}_i&=Linear(C_i,C)(F_i),\forall i \\ \hat{F}_i&=Upsample(\frac{W}{4}\times \frac{W}{4})(\hat{F}_i),\forall i \\ F&=Linear(4C,C)(Concat(\hat{F}_i)),\forall i \\ M&=Linear(C,N_{cls})(F) \end{aligned}
F^i?F^i?FM?=Linear(Ci?,C)(Fi?),?i=Upsample(4W?×4W?)(F^i?),?i=Linear(4C,C)(Concat(F^i?)),?i=Linear(C,Ncls?)(F)?
2.4 有效接受视野(ERF)
这个部分是 用来证明 解码器是非常有效的
对于语义分割,保持较大的接受域以包含上下文信息一直是一个中心问题。SegFormer使用有效接受域(ERF)作为一个工具包来可视化和解释为什么All-MLP译码器设计在TransFormer上如此有效。在下图中可视化了DeepLabv3+和SegFormer的四个编码器阶段和解码器头的ERF:
从上图中可以观察到:
- DeepLabv3+的ERF即使在最深层的Stage4也相对较小。
- SegFormer编码器自然产生局部注意,类似于较低阶段的卷积,同时能够输出高度非局部注意,有效捕获Stage4的上下文。
- 如放大Patch所示,MLP头部的ERF(蓝框)与Stage4(红框)不同,其非局部注意力和局部注意力显著增强。
CNN的接受域有限,需要借助语境模块扩大接受域,但不可避免地使网络变复杂。All-MLP译码器设计得益于transformer中的非局部注意力,并在不复杂的情况下导致更大的接受域。然而,同样的译码器设计在CNN主干上并不能很好地工作,因为整体的接受域是在Stage4的有限域的上限。
更重要的是,All-MLP译码器设计本质上利用了Transformer诱导的特性,同时产生高度局部和非局部关注。通过统一它们,All-MLP译码器通过添加一些参数来呈现互补和强大的表示。这是推动我们设计的另一个关键原因。
3. 代码
下面展示的SegFormer 的Bo版本。其他版本,可以自己调整
from math import sqrt
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
class DsConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride=1, bias=True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
bias=bias),
nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
)
def forward(self, x):
return self.net(x)
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x))
class EfficientSelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads,
reduction_ratio
):
"""
自注意力层
Args:
dim: 输入维度
heads: 注意力头数
reduction_ratio: 缩放因子
"""
super().__init__()
self.scale = (dim // heads) ** -0.5
self.heads = heads
self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False)
self.to_out = nn.Conv2d(dim, dim, 1, bias=False)
def forward(self, x):
h, w = x.shape[-2:]
heads = self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
return self.to_out(out)
class MixFeedForward(nn.Module):
def __init__(
self,
*,
dim,
expansion_factor
):
super().__init__()
hidden_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1),
DsConv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1)
)
def forward(self, x):
return self.net(x)
class MiT(nn.Module):
def __init__(
self,
*,
channels,
dims,
heads,
ff_expansion,
reduction_ratio,
num_layers
):
"""
Mix Transformer Encoder
Args:
channels:
dims:
heads:
ff_expansion:
reduction_ratio:
num_layers:
"""
super().__init__()
stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
dims = (channels, *dims)
dim_pairs = list(zip(dims[:-1], dims[1:]))
self.stages = nn.ModuleList([])
for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in \
zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
layers = nn.ModuleList([])
for _ in range(num_layers):
layers.append(nn.ModuleList([
PreNorm(dim_out, EfficientSelfAttention(dim=dim_out,
heads=heads,
reduction_ratio=reduction_ratio)),
PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
]))
self.stages.append(nn.ModuleList([
get_overlap_patches,
overlap_patch_embed,
layers
]))
def forward(self, x, return_layer_outputs=False):
h, w = x.shape[-2:]
layer_outputs = []
for (get_overlap_patches, overlap_embed, layers) in self.stages:
x = get_overlap_patches(x)
num_patches = x.shape[-1]
ratio = int(sqrt((h * w) / num_patches))
x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
x = overlap_embed(x)
for (attn, ff) in layers:
x = attn(x) + x
x = ff(x) + x
layer_outputs.append(x)
ret = x if not return_layer_outputs else layer_outputs
return ret
class Segformer(nn.Module):
def __init__(
self,
*,
dims=(32, 64, 160, 256),
heads=(1, 2, 5, 8),
ff_expansion=(8, 8, 4, 4),
reduction_ratio=(8, 4, 2, 1),
num_layers=2,
channels=3,
decoder_dim=256,
num_classes=19
):
"""
Args:
dims: 4个阶段,出来的通道数
heads: 每个阶段,使用的注意力头数目
ff_expansion: mix-ffn 中 3*3卷积的扩张倍率
reduction_ratio: 自注意力层缩放因子
num_layers: 每个transformer blocks块重复的次数
channels: 输入通道数,一般为3
decoder_dim: 解码器维度 。 作用: 编码器的特征图统一 上采样--> decoder_dim 维度
num_classes: 分类数目
"""
super().__init__()
dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4),
(dims, heads, ff_expansion, reduction_ratio,
num_layers))
assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio,
num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
self.mit = MiT(
channels=channels,
dims=dims,
heads=heads,
ff_expansion=ff_expansion,
reduction_ratio=reduction_ratio,
num_layers=num_layers
)
self.to_fused = nn.ModuleList([nn.Sequential(
nn.Conv2d(dim, decoder_dim, 1),
nn.Upsample(scale_factor=2 ** (i))
) for i, dim in enumerate(dims)])
self.to_segmentation = nn.Sequential(
nn.Conv2d(4 * decoder_dim, decoder_dim, 1),
nn.Conv2d(decoder_dim, num_classes, 1),
)
def forward(self, x):
layer_outputs = self.mit(x, return_layer_outputs=True)
"""
torch.Size([1, 32, 56, 56])
torch.Size([1, 64, 28, 28])
torch.Size([1, 160, 14, 14])
torch.Size([1, 256, 7, 7])
"""
fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
fused = torch.cat(fused, dim=1)
fused = self.to_segmentation(fused)
return F.interpolate(fused, size=x.shape[2:], mode='bilinear', align_corners=False)
if __name__ == '__main__':
x = torch.randn(size=(1, 3, 224, 224))
model = Segformer()
print(model)
from thop import profile
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input,))
print("flops:{:.3f}G".format(flops / 1e9))
print("params:{:.3f}M".format(params / 1e6))
参考资料
https://blog.csdn.net/weixin_43610114/article/details/125000614
https://blog.csdn.net/weixin_44579633/article/details/121081763
https://blog.csdn.net/qq_39333636/article/details/124334384
语义分割之SegFormer分享_xuzz_498100208的博客-CSDN博客
论文笔记——Segformer: 一种基于Transformer的语义分割方法 - 知乎 (zhihu.com)
手把手教你使用Segformer训练自己的数据_中科哥哥的博客-CSDN博客
|