IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> U-Net论文详解 -> 正文阅读

[人工智能]U-Net论文详解

U-Net论文详解

UNet算法Pytorch实现:https://github.com/codecat0/CV/tree/main/Semantic_Segmentation/UNet

U-Net结构由一个用于捕获上下文信息的压缩路径和一个支持精确定位的对称扩展路径构成。实验结果表明可以从很少的图像进行端到端的训练,并在ISBI挑战上优于先前最优的方法(滑动窗口卷积网络),并获得了冠军

1. 背景介绍

卷积网络的典型应用是分类任务,其中图像的输出是一个单一的类标签。然而在许多视觉任务中,特别是生物医学图像处理中,期望的输出应该包含定位,即给每一个像素点分配一个类标签。

于是滑动窗口卷积网络通过提供像素点周围的局部区域来预测每个像素的类别标签。但是这样的方法存在两个缺点:

  1. 速度特别慢,网络必须为每一个窗口单元单独运行,并且窗口单元重合而导致大量冗余
  2. 在定位精度和上下文信息之间的权衡。大的窗口单元需要更多的max pooling层,这会降低精度;而小的窗口单元捕获的上下文信息较少。

于是本文提出了U-Net网络

2. U-Net网络架构

在这里插入图片描述

网络是一个经典的全卷积网络。网络的输入是一张572x572经过镜像操作的图像。为了使得每次下采样后特征图的尺寸为偶数。
在这里插入图片描述

网络的左侧为压缩路径,由4个block构成,每个block由2个未padding的卷积和一个最大池化构成,其中每次卷积特征图的尺寸为减小2,最大池化后会缩小一半。

现在大部分采用same padding的卷积,这样就不用对输入进行镜像操作,而且在拼接压缩路径与对应的扩展路径也不用进行裁剪,而且裁剪会使得特征图不对称

网络的右侧为扩展路径,同样由4个block构成,每个block开始之前通过反卷积将特征图的尺寸扩大一倍,然后与压缩路径对应的特征图拼接,由于采用未padding的卷积,左侧压缩路径的特征图的尺寸比右侧扩展路径的特征图的大,所以需要先进行裁剪,使其大小相同,然后拼接,然后经过两次未padding的卷积进一步提取特征

最后根据自己的任务,输出对应大小的预测特征图

现在大部分采用双线性插值代替反卷积,而且效果会更好

3. 数据增强

我们主要通过平移和旋转不变性以及灰度值的变化来增强模型的鲁棒性,特别地,任意的弹性形变对训练非常有帮助

4. Pytorch实现

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x_pooled = self.pool(x)
        return x, x_pooled


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.up_sample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_prev, x):
        x = self.up_sample(x)
        x_shape = x.shape[2:]
        x_prev_shape = x.shape[2:]
        h_diff = x_prev_shape[0] - x_shape[0]
        w_diff = x_prev_shape[1] - x_shape[1]
        # padding
        x_tmp = torch.zeros(x_prev.shape).to(x.device)
        x_tmp[:, :, h_diff//2: h_diff+x_shape[0], w_diff//2: x_shape[1]] = x
        x = torch.cat([x_prev, x_tmp], dim=1)
        x = self.block1(x)
        x = self.block2(x)
        return x



class UNet(nn.Module):
    # https://arxiv.org/abs/1505.04597
    def __init__(self, num_classes=2):
        super(UNet, self).__init__()

        self.down_sample1 = Encoder(in_channels=3, out_channels=64)
        self.down_sample2 = Encoder(in_channels=64, out_channels=128)
        self.down_sample3 = Encoder(in_channels=128, out_channels=256)
        self.down_sample4 = Encoder(in_channels=256, out_channels=512)

        self.mid1 = nn.Sequential(
            nn.Conv2d(512, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )
        self.mid2 = nn.Sequential(
            nn.Conv2d(1024, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )

        self.up_sample1 = Decoder(in_channels=1024, out_channels=512)
        self.up_sample2 = Decoder(in_channels=512, out_channels=256)
        self.up_sample3 = Decoder(in_channels=256, out_channels=128)
        self.up_sample4 = Decoder(in_channels=128, out_channels=64)

        self.classifier = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        x1, x = self.down_sample1(x)
        x2, x = self.down_sample2(x)
        x3, x = self.down_sample3(x)
        x4, x = self.down_sample4(x)

        x = self.mid1(x)
        x = self.mid2(x)

        x = self.up_sample1(x4, x)
        x = self.up_sample2(x3, x)
        x = self.up_sample3(x2, x)
        x = self.up_sample4(x1, x)

        x = self.classifier(x)
        return x



if __name__ == '__main__':
    input = torch.rand(1, 3, 384, 384)
    model = UNet(2)
    out = model(input)
    print(out.shape)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-28 22:55:34  更:2021-12-28 22:55:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 20:39:08-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码