IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 附代码 DFNet :Learning a Discriminative Feature Network for Semantic Segmentation -> 正文阅读

[人工智能]附代码 DFNet :Learning a Discriminative Feature Network for Semantic Segmentation

Learning a Discriminative Feature Network for Semantic Segmentation论文解读

代码链接:https://github.com/lxtGH/dfn_seg

摘要:

我们提出了一个判别特征网络(DFN),它包含两个子网络:平滑网络和边界网络。具体来说,为了处理类内不一致问题,我们特别设计了一个具有通道注意块和全局平均池的平滑网络来选择更明显的区分特征。此外,我们提出了一种边界网络解决不同类之间的问题,通过深度语义边界监督来区分边界的双边特征。提出的平滑网络旨在解决类内不一致的问题。

贡献

  1. 我们从一个新的宏观角度重新思考语义分割任务。我们将语义分割视为一项任务,为一个类别的事物分配一个一致的语义标签,而不仅仅是在像素级。
  2. 提出了一种区分性特征网络来同时解决“类内一致性”和“类间变异”的问题。
  3. 我们提出了一个平滑的网络来增强与全局上下文和通道注意块的类内一致性。
  4. 设计了一个具有深度监督的自低级到高级特征的边界网络,以扩大语义边界两侧特征的变化。这也可以细化预测的语义边界。

网络结构

如图,在平滑网络中,我们在网络的顶部添加了全局平均池化层,以获得最强的一致性。然后,我们利用通道注意块来改变通道的权值,以进一步增强一致性。同时,在边界网络中,通过显式的语义边界监督,该网络获得了准确的语义边界,使双边特征更加明显。
在这里插入图片描述

现有方法对比:

  • Encoder-Decoder:这种类型的体系结构忽略了全局上下文。此外,这种类型的大多数方法只是总结了相邻阶段的特征,而没有考虑到它们的不同表示。这导致了一些不一致的结果。
  • Global Context:一些现有方法已经证明了全局平均池的有效性。ParseNet首先在语义分割任务中应用全局平均池。然后PSPNet和Deeplabv3分别将其扩展到空间金字塔池和空间空间金字塔池,在不同的基准测试中取得了很好的性能。然而,为了充分利用金字塔池模块,这两种方法采用空洞卷积进行下采样,耗时且内存较大。
  • Attention Module:注意力机制有助于关注我们想要的东西。近年来,注意模块可以关注不同的尺度信息。在这项工作中,我们利用通道注意力(类似于SENet)来选择特征。

平滑网络( Smooth Network):利用高阶段的一致性来指导低阶段的最优预测

  • 我们的平滑网络是基于U型结构来捕获多尺度的上下文信息,并使用全局平均池化来捕获全局上下文。此外,我们还提出了一种通道注意块(CAB),它利用高级特征来指导低级特征的逐步选择。
  • 类内不一致性问题主要是由于缺乏上下文。因此,我们引入了具有全局平均池的全局上下文。而全局上下文仅具有较高的语义信息,因此需要多尺度的感受野和背景来细化空间信息,即选择更多阶段的特征来预测。故使用ResNet作为一个基识别模型。该模型根据特征图的大小可分为五个阶段
  • 当网络结合相邻阶段的特征时,它只是通过通道来总结这些特征。这个操作忽略了不同阶段的不同一致性。为了弥补这一缺陷,首先嵌入一个全局平均池化层,再通过通道注意力块(Channel attention block)来结合相邻阶段的特征。

Channel attention block(通道注意力块)

结构如下图,图b为注意力分数变量。
在这里插入图片描述
在这里插入图片描述

通道中含有不同stage的输出,而其重要性不一样,因此我们需要对通道引入注意力机制,来获取不同stage的通道重要性。即需要提取可以判别的特征并抑制不可以判别的特征。

代码

class CAB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high, low
        x = torch.cat([x1,x2],dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2#低级特征获得语义信息
        res = x2 + x1
        return res

Refinement residual block:该块可以加强每个阶段的识别能力。

在这里插入图片描述

边界网络(Border Network):辅助损失

  • 边界网络试图区分具有相似外观但不同语义标签的相邻补丁。在训练过程中整合语义边界损失来学习区分特征,以扩大“类间的区别”。

  • 从低阶段获得准确的边缘信息,从高阶段获得语义信息。

  • 该方法是通过使用传统的图像处理方法,类似Canny算法,获取图像的轮廓信息,将此轮廓信息作为边界网络的label,计算该处的损失值。
    在这里插入图片描述

损失函数:

我们使用深度监督来获得更好的性能,使网络更容易优化。在平滑网络中,我们使用softmox最大损失来监督每个阶段的上采样输出。而我们使用focal loss 来监督边界网络的输出。lambda取值为0.1效果最好。总损失如下:
在这里插入图片描述

代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Xiangtai(lxtpku@pku.edu.cn)
# Implementation of Paper Learning a Discriminative Feature Network for Semantic Segmentation (CVPR2018)(face_plus_plus)


import torch
import torch.nn as nn
from models.resnet import resnet101


class CAB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high, low
        x = torch.cat([x1,x2],dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2#低级特征获得语义信息
        res = x2 + x1
        return res

class RRB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(RRB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        res  = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        return self.relu(x + res)


class DFN(nn.Module):
    def __init__(self, num_class=20):
        super(DFN, self).__init__()
        self.num_class = num_class
        self.resnet_features = resnet101(pretrained=False
)
        self.layer0 = nn.Sequential(self.resnet_features.conv1, self.resnet_features.bn1,
                                    self.resnet_features.relu1, self.resnet_features.conv3,
                                    self.resnet_features.bn3, self.resnet_features.relu3
                                    )
        self.layer1 = nn.Sequential(self.resnet_features.maxpool, self.resnet_features.layer1)
        self.layer2 = self.resnet_features.layer2
        self.layer3 = self.resnet_features.layer3
        self.layer4 = self.resnet_features.layer4

        # this is for smooth network
        self.out_conv = nn.Conv2d(2048,self.num_class,kernel_size=1,stride=1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.cab1 = CAB(self.num_class*2,self.num_class)
        self.cab2 = CAB(self.num_class*2,self.num_class)
        self.cab3 = CAB(self.num_class*2,self.num_class)
        self.cab4 = CAB(self.num_class*2,self.num_class)

        self.rrb_d_1 = RRB(256, self.num_class)
        self.rrb_d_2 = RRB(512, self.num_class)
        self.rrb_d_3 = RRB(1024, self.num_class)
        self.rrb_d_4 = RRB(2048, self.num_class)

        self.upsample = nn.Upsample(scale_factor=2,mode="bilinear")
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear")
        self.upsample_8 = nn.Upsample(scale_factor=8, mode="bilinear")

        self.rrb_u_1 = RRB(self.num_class,self.num_class)
        self.rrb_u_2 = RRB(self.num_class,self.num_class)
        self.rrb_u_3 = RRB(self.num_class,self.num_class)
        self.rrb_u_4 = RRB(self.num_class,self.num_class)


        ## this is for boarder net work
        self.rrb_db_1 = RRB(256, self.num_class)
        self.rrb_db_2 = RRB(512, self.num_class)
        self.rrb_db_3 = RRB(1024, self.num_class)
        self.rrb_db_4 = RRB(2048, self.num_class)

        self.rrb_trans_1 = RRB(self.num_class,self.num_class)
        self.rrb_trans_2 = RRB(self.num_class,self.num_class)
        self.rrb_trans_3 = RRB(self.num_class,self.num_class)

    def forward(self, x):
        # suppose input = x , if x 512
        f0 = self.layer0(x)  # 256
        f1 = self.layer1(f0)  # 128
        f2 = self.layer2(f1)  # 64
        f3 = self.layer3(f2)  # 32
        f4 = self.layer4(f3)  # 16

        # for border network
        res1 = self.rrb_db_1(f1)
        res1 = self.rrb_trans_1(res1 + self.upsample(self.rrb_db_2(f2)))
        res1 = self.rrb_trans_2(res1 + self.upsample_4(self.rrb_db_3(f3)))
        res1 = self.rrb_trans_3(res1 + self.upsample_8(self.rrb_db_4(f4)))
        # print (res1.size())
        # for smooth network
        res2 = self.out_conv(f4)
        res2 = self.global_pool(res2)  #
        res2 = nn.Upsample(size=f4.size()[2:],mode="nearest")(res2)#由于RRB没有降低尺寸,因此需要加入upsample将res2的尺寸拉回f4

        f4 = self.rrb_d_4(f4)
        res2 = self.cab4([res2,f4])
        res2 = self.rrb_u_1(res2)

        f3 = self.rrb_d_3(f3)
        res2 = self.cab3([self.upsample(res2),f3])
        res2 =self.rrb_u_2(res2)

        f2 = self.rrb_d_2(f2)
        res2 = self.cab2([self.upsample(res2), f2])
        res2 =self.rrb_u_3(res2)

        f1 = self.rrb_d_1(f1)
        res2 = self.cab1([self.upsample(res2), f1])
        res2 = self.rrb_u_4(res2)

        return res1, res2

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

if __name__ == '__main__':
    model = DFN(20)
    model.freeze_bn()
    model.eval()
    image = torch.autograd.Variable(torch.randn(1, 3, 512, 512), volatile=True)
    res1, res2 = model(image)
    print (res1.size(), res2.size())

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-22 20:35:25  更:2022-02-22 20:37:55 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 11:22:27-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码