IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> SwinUNet2022 -> 正文阅读

[人工智能]SwinUNet2022

1. 概述

本文提出了一种以 S w i n Swin Swin变压器层为基本块的 S U N e t SUNet SUNet恢复模型,并将其应用于 U N e t UNet UNet架构中进行图像去噪。

2. 背景

图像恢复是一种重要的低级图像处理方法,可以提高其在目标检测、图像分割和图像分类等高级视觉任务中的性能。在一般的恢复任务中,一个被损坏的图像Y可以表示为:
Y = D ( X ) + n (1) Y=D(X)+n \tag 1 Y=D(X)+n(1)
其中 X X X是一个干净的图像, D ( ? ) D(\cdot) D(?表示退化函数, n n n表示加性噪声。一些常见的恢复任务是去噪、去模糊和去阻塞。

2.1 CNN局限性

虽然大多数基于卷积神经网络(CNN)的方法都取得了良好的性能,但卷积层存在几个问题。首先,卷积核与图像的内容无关(无法与图像内容相适应)。使用相同的卷积核来恢复不同的图像区域可能不是最好的解决方案。其次,由于卷积核可以看作是一个小块,其中获取的特征是局部信息,换句话说,当我们进行长期依赖建模时,全局信息就会丢失。

3. 结构

3.1 UNet

目前,UNet由于具有层次特征映射来获得丰富的多尺度上下文特征,是许多图像处理应用中著名的架构。此外,它利用编码器和解码器之间的跳跃连接来增强图像的重建过程。UNet被广泛应用于许多计算机视觉任务,如分割、恢复[。此外,它还有各种改进的版本,如Res-UNet,Dense-UNet,Attention-UNet[和Non-local-UNet。由于具有较强的自适应骨干网,UNet可以很容易地应用于不同的提取块,以提高性能。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MpLyfXEs-1647427670075)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316140830395.png)]

3.2 Swin Transformer

Transformer模型在自然语言处理(NLP)领域取得了成功,并具有良好的竞争性能,特别是在图像分类方面。然而,直接使用Transformer到视觉任务的两个主要问题是:

(1)图像和序列之间的尺度差异很大。由于Transformer需要参数量为一维序列参数的平方倍,所以存在长序列建模的缺陷。

(2)Transformer不擅长解决实例分割等密集预测任务,即像素级任务。然而,Swin Transfomer通过滑动窗口解决了上述问题,降低了参数,并在许多像素级视觉任务中实现了最先进的性能。

3.3 SUNet

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mK2P8qBG-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316141320770.png)]

所提出的Swin Transformer UNet(SUNet)的架构是基于图像分割模型,如上图所示。SUNet由三个模块组成:

(1)浅层特征提取;

(2)UNet特征提取;

(3)重建模块

浅层特征提取模块:

对于有噪声的输入图像 Y ∈ R H × W × 3 Y∈R^{H×W×3} YRH×W×3,其中H,W为失真图像的分辨率。我们使用单个3×3卷积层 M S F E ( ? ) M_{SFE}(\cdot) MSFE?(?)获取输入图像的颜色或纹理等低频信息。浅特征 F s h a l l o w ∈ R H × W × C F_{shallow}∈R^{H×W×C} Fshallow?RH×W×C可以表示为:
F s h a l l o w = M S F E ( Y ) (2) F_{shallow}=M_{SFE}(Y) \tag 2 Fshallow?=MSFE?(Y)(2)
其中,C是浅层特征的通道数,在后一个实验部分中,我们都设置为96.

UNet 特征提取网络:

然后,将浅层特征 F s h a l l o w F_{shallow} Fshallow?输入UNet特征提取 M U F E ( ? ) M_{UFE}(\cdot) MUFE?(?),UNet用来提取高级、多尺度深度特征 F d e e p ∈ R H × W × C F_{deep}∈R^{H×W×C} Fdeep?RH×W×C
F d e e p = M U F E ( F s h a l l o w ) (3) F_{deep}=M_{UFE}(F_{shallow}) \tag 3 Fdeep?=MUFE?(Fshallow?)(3)
其中, M U F E ( ? ) M_{UFE}(\cdot) MUFE?(?)是带有Swin变压器块的UNet架构,它在单个块中包含8个Swin Transformer层来代替卷积。Swin Transformer Block(STB)和Swin Transformer Layer(STL)将在下一小节中进行详细说明。

重建层:

最后,我们仍然使用3×3卷积 M R ( ? ) M_{R}(\cdot) MR?(?)从深度特征 F d e e p F_{deep} Fdeep?中生成无噪声图像 X ^ ∈ R H × W × 3 \hat{X}∈R^{H×W×3} X^RH×W×3,其公式为:
X ^ = M R ( F d e e p ) (4) \hat{X}=M_{R}(F_{deep}) \tag 4 X^=MR?(Fdeep?)(4)
注意, X ^ \hat{X} X^是以噪声图像 Y Y Y作为SUNet的输入得到的,其中 X X X是(1)中Y图像的原高分率图像。

3.4 Loss function

我们优化了我们的SUNet端到端与规则的 L 1 L1 L1像素损失的图像去噪:
L d e n o i s e = ∣ ∣ X ^ ? X ∣ ∣ 1 (5) L_{denoise}=||\hat{X}-X||_1 \tag 5 Ldenoise?=X^?X1?(5)

3.5 Swin Transformer Block

在UNet提取模块中,我们使用STB来代替传统的卷积层,如下图所示。STL是基于NLP中的原始Transformer Layer。STL的数量总是2的倍数,其中一个是window multi-head-self-attention(W-MSA),另一个是shifted-window multi-head self-attention(SW-MSA)。

STL的公式描述:
f ^ L = W ? M S A ( L N ( f L ? 1 ) ) + f L ? 1 f L = M L P ( L N ( f ^ L ) ) + f ^ L f ^ L + 1 = S W ? M S A ( L N ( f L ) ) + f L f L + 1 = M L P ( L N ( f ^ L + 1 ) ) + f ^ L + 1 (6) \hat{f}^L=W-MSA(LN(f^{L-1}))+f^{L-1} \\ f^L=MLP(LN(\hat{f}^L))+\hat{f}^L \\ \hat{f}^{L+1}=SW-MSA(LN(f^{L}))+f^{L} \\ f^{L+1}=MLP(LN(\hat{f}^{L+1}))+\hat{f}^{L+1} \tag 6 f^?L=W?MSA(LN(fL?1))+fL?1fL=MLP(LN(f^?L))+f^?Lf^?L+1=SW?MSA(LN(fL))+fLfL+1=MLP(LN(f^?L+1))+f^?L+1(6)
其中, L N ( ? ) LN(\cdot) LN(?)表示为层归一化, M L P MLP MLP是多层感知器,它具有两个完全连接的层,同时后面跟一个线性单位(GELU)激活函数。

3.6 Resizing module

由于UNet具有不同的特征图尺度,因此调整大小的模块(例如,下样本和上样本)是必要的。在我们的SUNet中,我们使用 p a t c h ? m e r g i n g patch\ merging patch?merging,并提出 d u a l ? u p ? s a m p l e dual\ up-sample dual?up?sample分别作为下样本和上样本模块。

3.6.1 patch merging

对于降采样模块,该文将每一组2×2相邻斑块的输入特征连接起来,然后使用线性层获得指定的输出通道特征。我们也可以把这看作是做卷积操作的第一步,也就是展开输入的特征映射。

3.6.2 Dual up-sample

对于上采样,原始的Swin-UNet采用patch expanding方法,等价于上采样模块中的转置卷积。然而,转置卷积很容易面对块效应。在这里,我们提出了一个新的模块,称为双上样本,它包括两种现有的上样本方法(即Bilinear和PixelShuffle),以防止棋盘式的artifacts。所提出的上采样模块的体系结构如下图所示。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-joxathTz-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316145653720.png)]

4. 结果

评估指标:

为了进行定量比较,我们考虑了峰值信噪比(PSNR)和结构相似度(SSIM)指数度量。

训练集:

采用DIV2K作为训练集,一共有900张高清图片。我们对每个训练图像随机裁剪100个大小为 256 × 256 256×256 256×256的斑块,并对 800 800 800张训练图像从 σ = 5 σ=5 σ=5 σ = 50 σ=50 σ=50 p a t c h patch patch中随机添加AWGN噪声。至于验证集,我们直接使用包含100张图像的测试集,并添加具有三种不同噪声水平的AWGN, σ = 10 、 σ = 30 和 σ = 50 σ=10、σ=30和σ=50 σ=10σ=30σ=50

测试集:

对于评估,我们选择了CBSD68数据集,它有68张彩色图像,分辨率为768×512,以及Kodak24张数据集,由24张图像组成,图像大小为321×481。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WdJUyCFm-1647427670078)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316170249843.png)]

在表1中,我们对去噪图像进行了客观的质量评价,并观察到以下三件事:

(1)该文的SUNet具有竞争性的SSIM值,因为Swin-Transformer是基于全局信息(q,k,v可以提取全局信息),使得去噪图像拥有更多的视觉效果。

(2)与基于unet的方法(DHDN、RDUNet)相比,该文所提出的SUNet模型中参数(↓60%)和FLOPs(↓3%)较少,在PSNR和SSIM上仍保持良好的得分

(3)与基于cnn的方法(DnCNN,IrCNN,FFDNet)相比,该文得到了其中最好的PSNR和SSIM结果,以及几乎相同的FLOPs。虽然该文的模型的参数最多(99M),但它是由于自注意操作不能共享核的权值造成的。

4. 总结

  • 提出了一种基于图像分割的双unet模型的双变换网络进行图像去噪。
  • 该文提出了一种双上样本块结构,它包括亚像素方法和双线性上样本方法,以防止棋盘伪影。实验结果表明,该方法优于转置卷积的原始上样本。
  • 该文的模型是第一个结合Swin变压器和UNet进行去噪的模型。

Reference: Swin-unet: Unet-like pure transformer for medical image segmentation

5. 某些代码的理解

5.1 window attention中的相对位置编码

代码位置位于: ./model/SUNet_detail.py 中的89行左右

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KhTSbrYb-1647503570456)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316205927221.png)]

我这里展示了一个例子:

>>> import torch
>>> coords_h=torch.arange(3)
>>> coords_w=torch.arange(3)
>>> coords=torch.stack(torch.meshgrid([coords_h,coords_w]))
>>> coords
tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])
>>> coords_flatten=torch.flatten(coords,1)
>>> coords_flatten
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
>>> relative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]
>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])

可以看到 r e l a t i v e _ c o o r d s relative\_coords relative_coords的第一维是2,分别对应x轴和y轴方向(或者高,宽的方向)。剩下两维呢是9*9。SUNet中间使用了一些SwinIR的结构,在SwinIR中是有shift-window的,在这里,我设置的window size为3。又因为再做attention的时候,我们把每一个window中的像素点当作一个token,那么最终的attention map( q ? v q * v q?v)的最后两维就是 w i n d o w _ w i d t h ? w i n d o w _ h e i g h t window\_width \cdot window\_height window_width?window_height

下面再来看具体的物理意义, 3 × 3 3 \times 3 3×3的window一共有9个数值,第一维度分别代表这两个轴;第一个矩阵中,第一行分别代表着第一个数值(一共有9个)在某个轴上相对于其他位置的距离(在第一个数值的右边为负,左边为正),第二个矩阵类似。不同window的相对位置是一样的。

>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])
>>> relative_coords = relative_coords.permute(1, 2, 0).contiguous()
>>> relative_coords
tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2]],

        [[ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1]],

        [[ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0]],

        [[ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2]],

        [[ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1]],

        [[ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0]]])

下面的代码就是将相对位置坐标全部加上 w i n d o w _ s i z e ? 1 window\_size-1 window_size?1,使得全部为正值:

>>> relative_coords[:, :, 0] += window_size[0] - 1
>>> relative_coords[:, :, 1] += window_size[1] - 1
>>> relative_coords
tensor([[[2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0],
         [0, 2],
         [0, 1],
         [0, 0]],

        [[2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1],
         [0, 3],
         [0, 2],
         [0, 1]],

        [[2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2],
         [0, 4],
         [0, 3],
         [0, 2]],

        [[3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0]],

        [[3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1]],

        [[3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2]],

        [[4, 2],
         [4, 1],
         [4, 0],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0]],

        [[4, 3],
         [4, 2],
         [4, 1],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1]],

        [[4, 4],
         [4, 3],
         [4, 2],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2]]])

下面是将横纵坐标的相对位置加起来:

>>> relative_position_index = relative_coords.sum(-1)
>>> relative_position_index
tensor([[4, 3, 2, 3, 2, 1, 2, 1, 0],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [8, 7, 6, 7, 6, 5, 6, 5, 4]])

下面为定义bias,随机初始化,但是在网络的迭代训练中,是会被反向传播的

 # define a parameter table of relative position bias
 self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] 			- 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
 trunc_normal_(self.relative_position_bias_table, std=.02)

一个window里面明明只有9个数值,为什么定义bias时,矩阵的维度为== 25 , n u m h e a d s 25, num_heads 25,numh?eads==呢?**

这是因为上面我们加了 w i n d o w _ s i z e ? 1 window\_size -1 window_size?1:不加之前最大值为 w i n d o w _ s i z e ? 1 window\_size-1 window_size?1,后面在加上 w i n d o w _ s i z e ? 1 window\_size-1 window_size?1,此时,最大值为 2 × w i n d o w _ s i z e ? 2 2\times window\_size -2 2×window_size?2,再算上零,一共有 2 × w i n d o w _ s i z e ? 1 2\times window\_size-1 2×window_size?1。所以再初始化bias的时候,我觉得维度为 2 × w i n d o w _ s i z e ? 1 2\times window\_size-1 2×window_size?1

就够了,不知道为什么要定义 ( 2 × w i n d o w _ s i z e ? 1 ) × 2 × w i n d o w _ s i z e ? 1 (2\times window\_size-1)\times 2\times window\_size-1 (2×window_size?1)×2×window_size?1呢?

每个window是独立attention的,所以每个window的relative_position_bias都是一样的。

下面就是加MASK操作了,不再赘述。

5.2 Shifted-window-attention

代码位于 S w i n T r a n s f o r m e r B l o c k SwinTransformerBlock SwinTransformerBlock中。

s h i f t _ s i z e > 0 shift\_size>0 shift_size>0时:

# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)  
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, 			 	float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> img_mask = torch.zeros((1, 3, 3, 1))
>>> h_slices=(slice(0,-3),slice(-3,1),slice(-1,None))
>>> h_slices
(slice(0, -3, None), slice(-3, 1, None), slice(-1, None, None))
>>> w_slices=(slice(0,-3),slice(-3,-1),slice(-1,None))
>>> w_slices
(slice(0, -3, None), slice(-3, -1, None), slice(-1, None, None))
>>> cnt = 0
>>> for h in h_slices:
...     for w in w_slices:
...             img_mask[:,h,w,:]=cnt
...             cnt+=1
...
>>> img_mask
tensor([[[[4.],
          [4.],
          [5.]],

         [[4.],
          [4.],
          [5.]],

         [[7.],
          [7.],
          [8.]]]])
# nW, window_size, window_size, 1    [1,3,3,1]
mask_windows = window_partition(img_mask, self.window_size)
# 经过window_partition后,没有发生变化
>>> mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
tensor([[4., 4., 5., 4., 4., 5., 7., 7., 8.]])
>>> attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
>>> attn_mask.size()
torch.Size([1, 9, 9])
>>> attn_mask
tensor([[[ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-4., -4., -3., -4., -4., -3., -1., -1.,  0.]]])
>>> attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我们再来看forward函数对x的shift操作:

>>> x=torch.arange(0,9)
>>> x
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> x=x.unsqueeze(0)+x.unsqueeze(1)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 7,  8,  9, 10, 11, 12, 13, 14, 15],
        [ 8,  9, 10, 11, 12, 13, 14, 15, 16]])
>>> x=x.unsqueeze(2).unsqueeze(0)
>>> x.size()
torch.Size([1, 9, 9, 1])
>>> shifted_x = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
>>> xx=shifted_x.squeeze(3).squeeze(0)
>>> xx
tensor([[ 2,  3,  4,  5,  6,  7,  8,  9,  1],
        [ 3,  4,  5,  6,  7,  8,  9, 10,  2],
        [ 4,  5,  6,  7,  8,  9, 10, 11,  3],
        [ 5,  6,  7,  8,  9, 10, 11, 12,  4],
        [ 6,  7,  8,  9, 10, 11, 12, 13,  5],
        [ 7,  8,  9, 10, 11, 12, 13, 14,  6],
        [ 8,  9, 10, 11, 12, 13, 14, 15,  7],
        [ 9, 10, 11, 12, 13, 14, 15, 16,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  0]])

那为什么要加mask呢,它是由一个假设的,假设各个window之间不相关,各个window单独做attention。其中我们以 H = 9 , W = 9 , w i n d o w _ s i z e = 3 , s h i f t _ s i z e = 1 H=9,W=9, window\_size=3, shift\_size=1 H=9,W=9,window_size=3,shift_size=1为例。

图片参考:SWin Transformer

没有进行shift的window划分图:
在这里插入图片描述

再forward中,是会对输入的x进行shift操作的:
在这里插入图片描述

shift后的操作:

在这里插入图片描述

其中,上图黑线代表原来的边界。每个彩色框代表经过shift操作后的window划分,可以看到每个彩色框内部黑线位置是一样的;黑线是window的边界。

>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我们举个例子说明:

以第一个彩色框为例,第一行代表第一个元素是否可以看到对应位置的元素(0代表看得到,-100代表看不到)。当两个像素点位于不同的window之中(黑线)就是看不到,就赋给一个负数,后面再做softmax,对应权重就会非常小。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-COD9Kl1R-1647503570459)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220317132729985.png)]

5.3 Dual Up Sample

使用PixelShuffle和bilinear合起来的特征作为输出。

5.4 Absolute position embedding

# absolute position embedding
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, 			                embed_dim)) 
    trunc_normal_(self.absolute_pos_embed, std=.02)

大概位置再SUNet的635行左右。

那为什么要加Absolute position embedding呢?

是因为再5.2 window attention中呢,是有一个relative position的,但是relative position作用范围仅仅是在一个window里面,即每个window相同位置上的relative position都是一样的。所以需要absolute position。

5.5 DownSampling

下采样是通过Patch Embedding实现的。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-17 22:08:21  更:2022-03-17 22:09:04 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 14:29:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码