Learning a Discriminative Feature Network for Semantic Segmentation论文解读
代码链接:https://github.com/lxtGH/dfn_seg
摘要:
我们提出了一个判别特征网络(DFN),它包含两个子网络:平滑网络和边界网络。具体来说,为了处理类内不一致问题,我们特别设计了一个具有通道注意块和全局平均池的平滑网络来选择更明显的区分特征。此外,我们提出了一种边界网络解决不同类之间的问题,通过深度语义边界监督来区分边界的双边特征。提出的平滑网络旨在解决类内不一致的问题。
贡献:
- 我们从一个新的宏观角度重新思考语义分割任务。我们将语义分割视为一项任务,为一个类别的事物分配一个一致的语义标签,而不仅仅是在像素级。
- 提出了一种区分性特征网络来同时解决“类内一致性”和“类间变异”的问题。
- 我们提出了一个平滑的网络来增强与全局上下文和通道注意块的类内一致性。
- 设计了一个具有深度监督的自低级到高级特征的边界网络,以扩大语义边界两侧特征的变化。这也可以细化预测的语义边界。
网络结构
如图,在平滑网络中,我们在网络的顶部添加了全局平均池化层,以获得最强的一致性。然后,我们利用通道注意块来改变通道的权值,以进一步增强一致性。同时,在边界网络中,通过显式的语义边界监督,该网络获得了准确的语义边界,使双边特征更加明显。
现有方法对比:
- Encoder-Decoder:这种类型的体系结构忽略了全局上下文。此外,这种类型的大多数方法只是总结了相邻阶段的特征,而没有考虑到它们的不同表示。这导致了一些不一致的结果。
- Global Context:一些现有方法已经证明了全局平均池的有效性。ParseNet首先在语义分割任务中应用全局平均池。然后PSPNet和Deeplabv3分别将其扩展到空间金字塔池和空间空间金字塔池,在不同的基准测试中取得了很好的性能。然而,为了充分利用金字塔池模块,这两种方法采用空洞卷积进行下采样,耗时且内存较大。
- Attention Module:注意力机制有助于关注我们想要的东西。近年来,注意模块可以关注不同的尺度信息。在这项工作中,我们利用通道注意力(类似于SENet)来选择特征。
平滑网络( Smooth Network):利用高阶段的一致性来指导低阶段的最优预测
- 我们的平滑网络是基于U型结构来捕获多尺度的上下文信息,并使用全局平均池化来捕获全局上下文。此外,我们还提出了一种通道注意块(CAB),它利用高级特征来指导低级特征的逐步选择。
- 类内不一致性问题主要是由于缺乏上下文。因此,我们引入了具有全局平均池的全局上下文。而全局上下文仅具有较高的语义信息,因此需要多尺度的感受野和背景来细化空间信息,即选择更多阶段的特征来预测。故使用ResNet作为一个基识别模型。该模型根据特征图的大小可分为五个阶段。
- 当网络结合相邻阶段的特征时,它只是通过通道来总结这些特征。这个操作忽略了不同阶段的不同一致性。为了弥补这一缺陷,首先嵌入一个全局平均池化层,再通过通道注意力块(Channel attention block)来结合相邻阶段的特征。
Channel attention block(通道注意力块)
结构如下图,图b为注意力分数变量。
通道中含有不同stage的输出,而其重要性不一样,因此我们需要对通道引入注意力机制,来获取不同stage的通道重要性。即需要提取可以判别的特征并抑制不可以判别的特征。
代码:
class CAB(nn.Module):
def __init__(self, in_channels, out_channels):
super(CAB, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x1, x2 = x
x = torch.cat([x1,x2],dim=1)
x = self.global_pooling(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x2 = x * x2
res = x2 + x1
return res
Refinement residual block:该块可以加强每个阶段的识别能力。
边界网络(Border Network):辅助损失
-
边界网络试图区分具有相似外观但不同语义标签的相邻补丁。在训练过程中整合语义边界损失来学习区分特征,以扩大“类间的区别”。 -
从低阶段获得准确的边缘信息,从高阶段获得语义信息。 -
该方法是通过使用传统的图像处理方法,类似Canny算法,获取图像的轮廓信息,将此轮廓信息作为边界网络的label,计算该处的损失值。
损失函数:
我们使用深度监督来获得更好的性能,使网络更容易优化。在平滑网络中,我们使用softmox最大损失来监督每个阶段的上采样输出。而我们使用focal loss 来监督边界网络的输出。lambda取值为0.1效果最好。总损失如下:
代码
#!/usr/bin/env python
import torch
import torch.nn as nn
from models.resnet import resnet101
class CAB(nn.Module):
def __init__(self, in_channels, out_channels):
super(CAB, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x1, x2 = x
x = torch.cat([x1,x2],dim=1)
x = self.global_pooling(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x2 = x * x2
res = x2 + x1
return res
class RRB(nn.Module):
def __init__(self, in_channels, out_channels):
super(RRB, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.bn = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
res = self.conv2(x)
res = self.bn(res)
res = self.relu(res)
res = self.conv3(res)
return self.relu(x + res)
class DFN(nn.Module):
def __init__(self, num_class=20):
super(DFN, self).__init__()
self.num_class = num_class
self.resnet_features = resnet101(pretrained=False
)
self.layer0 = nn.Sequential(self.resnet_features.conv1, self.resnet_features.bn1,
self.resnet_features.relu1, self.resnet_features.conv3,
self.resnet_features.bn3, self.resnet_features.relu3
)
self.layer1 = nn.Sequential(self.resnet_features.maxpool, self.resnet_features.layer1)
self.layer2 = self.resnet_features.layer2
self.layer3 = self.resnet_features.layer3
self.layer4 = self.resnet_features.layer4
self.out_conv = nn.Conv2d(2048,self.num_class,kernel_size=1,stride=1)
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.cab1 = CAB(self.num_class*2,self.num_class)
self.cab2 = CAB(self.num_class*2,self.num_class)
self.cab3 = CAB(self.num_class*2,self.num_class)
self.cab4 = CAB(self.num_class*2,self.num_class)
self.rrb_d_1 = RRB(256, self.num_class)
self.rrb_d_2 = RRB(512, self.num_class)
self.rrb_d_3 = RRB(1024, self.num_class)
self.rrb_d_4 = RRB(2048, self.num_class)
self.upsample = nn.Upsample(scale_factor=2,mode="bilinear")
self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear")
self.upsample_8 = nn.Upsample(scale_factor=8, mode="bilinear")
self.rrb_u_1 = RRB(self.num_class,self.num_class)
self.rrb_u_2 = RRB(self.num_class,self.num_class)
self.rrb_u_3 = RRB(self.num_class,self.num_class)
self.rrb_u_4 = RRB(self.num_class,self.num_class)
self.rrb_db_1 = RRB(256, self.num_class)
self.rrb_db_2 = RRB(512, self.num_class)
self.rrb_db_3 = RRB(1024, self.num_class)
self.rrb_db_4 = RRB(2048, self.num_class)
self.rrb_trans_1 = RRB(self.num_class,self.num_class)
self.rrb_trans_2 = RRB(self.num_class,self.num_class)
self.rrb_trans_3 = RRB(self.num_class,self.num_class)
def forward(self, x):
f0 = self.layer0(x)
f1 = self.layer1(f0)
f2 = self.layer2(f1)
f3 = self.layer3(f2)
f4 = self.layer4(f3)
res1 = self.rrb_db_1(f1)
res1 = self.rrb_trans_1(res1 + self.upsample(self.rrb_db_2(f2)))
res1 = self.rrb_trans_2(res1 + self.upsample_4(self.rrb_db_3(f3)))
res1 = self.rrb_trans_3(res1 + self.upsample_8(self.rrb_db_4(f4)))
res2 = self.out_conv(f4)
res2 = self.global_pool(res2)
res2 = nn.Upsample(size=f4.size()[2:],mode="nearest")(res2)
f4 = self.rrb_d_4(f4)
res2 = self.cab4([res2,f4])
res2 = self.rrb_u_1(res2)
f3 = self.rrb_d_3(f3)
res2 = self.cab3([self.upsample(res2),f3])
res2 =self.rrb_u_2(res2)
f2 = self.rrb_d_2(f2)
res2 = self.cab2([self.upsample(res2), f2])
res2 =self.rrb_u_3(res2)
f1 = self.rrb_d_1(f1)
res2 = self.cab1([self.upsample(res2), f1])
res2 = self.rrb_u_4(res2)
return res1, res2
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if __name__ == '__main__':
model = DFN(20)
model.freeze_bn()
model.eval()
image = torch.autograd.Variable(torch.randn(1, 3, 512, 512), volatile=True)
res1, res2 = model(image)
print (res1.size(), res2.size())
|