IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PANet:基于金字塔注意力网络的图像超分辨率重建(Pytorch实现) -> 正文阅读

[人工智能]PANet:基于金字塔注意力网络的图像超分辨率重建(Pytorch实现)

PANet:基于金字塔注意力网络的图像超分辨率重建

[!] 为了提高代码的可读性,本文模型的具体实现与原文具有一定区别,因此会造成性能上的差异



1.相关资料


2.简介

  • PANet(Pyramid Attention with Simple Network Backbones)是一种基于图像恢复金字塔注意力模块的图像修复模型,它能够从多尺度特征金字塔种提取到长距离与短距离的特征关系。
  • 受降采样能够有效减少压缩伪影等图像噪声的启发,作者所提出的金字塔利用不同采样倍数的特征图来相互传递注意力信号,以更灵活的方式来借用不同特征尺寸之间的“干净”信息。
  • 作者只在一个简单的前馈链接网络中加入了一个金字塔注意力模块,就在绝大多数图像修复任务中达到了SOTA。(这样看来模块确实牛逼)

3.模型结构

直接上图
模型结构

  • 图上面部分就是传说中的金字塔注意力模块,图下面部分就是PANet的结构(这个结构和SRResNet怪像的,可以参考我的相关文章:SRResNetSRGAN
  • 金字塔注意力模块的结构分为两个部分:金字塔采样环节S-A Attention。金字塔采样环节就是简单的降采样处理,根据源代码来看,作者使用的是双二次下采样的方法。
  • S-A Attention的结构参考了NLP中最经典的注意力机制结构,即构建了Q,K,V三种特征图来捕获图像在不同尺寸中的信息。与其他注意力机制不同的是,S-A Attention将注意力机制中的按元素相乘环节改成将Q和K特征图作为卷积核(即图中浅蓝色特征层出来的两个特征图)来与V特征图进行卷积/反卷积操作。

4.项目实践

在这里我会一步一步教大家做一个能够成功运行的PANet,完整的代码也会很快推出。

4.1 准备工作

  • 笔者使用的工作环境如下所示:

    系统:Windows 10
    CPU:Intel Core i9-10850K
    GPU:GeForce RTX 3090
    
  • 实现代码所需要准备的库为:

    Pytorch
    OpenCV
    Numpy
    Torchvision
    
  • 本文使用的是COCO 2017数据集,其中包含了123,403张照片,大家可以根据自己的需要来使用自己的数据集。

4.2 具体实现

为了方便阅读,部分代码已标注中文注释,而且全部放进了一个代码文件中

  • 完整版代码支持重新打开代码自动恢复到上次训练的功能,只需要关注笔者即可获得:传送门

4.2.1 导入项目所需库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset,SubsetRandomSampler
import torch.optim as optim
from torchvision import utils as vutils
from torchvision.utils import save_image

import os
import cv2
import random as ra
import numpy as np
import math

4.2.2 构建数据集

class PreprocessDataset(Dataset):
   def __init__(self,path,size = 96):
       super().__init__()
       self.size = size  #高清图像的尺寸,这里默认为96x96

       self.allImgs = list()
       for root,dirs,files in os.walk(path):
           self.allImgs = [os.path.join(root,file) for file in files] #获取图像的地址

           
   def __len__(self):
       return len(self.allImgs)
   
   def __getitem__(self,index):
       img = self.allImgs[index]
       img = cv2.imread(img)            
       img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
       height,width,_ = img.shape
 
       xStart = ra.randint(0,width-self.size-1)
       yStart = ra.randint(0,height-self.size-1)

       img = img[yStart:self.size + yStart,xStart:self.size + xStart,:]  #随机裁剪图像
           
       if ra.random() > 0.5:
           img = cv2.flip(img,1)  #有50%几率反转图像
       
       hr = torch.tensor(np.transpose(img,(2,0,1)))/255.0
       hr = (hr - 0.5)/0.5  #像素标准化
       lr = F.max_pool2d(hr,2) #使用最大池化来获得下采样图片
       
       return hr,lr
  • 构建完数据集类后,我们可以很方便地构建对应的Dataloader。在这里我只构建了训练集,并没有构建测试集。

    path = '你的数据集文件路径'
    
    
    dataset = PreprocessDataset(path,size = 96)
    trainData = DataLoader(dataset,batch_size = 32,num_workers = 4,shuffle = True)
    

4.2.3 构建网络模型

# 特征金字塔部分

  • 这里直接改进了原作者的金字塔注意力模块代码,因此代码风格会与其他部分有一定差异。
    def extract_image_patches(images, ksizes, strides, rates, padding='same'):
        """
        Extract patches from images and put them in the C output dimension.
        :param padding:
        :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
        :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
         each dimension of images
        :param strides: [stride_rows, stride_cols]
        :param rates: [dilation_rows, dilation_cols]
        :return: A Tensor
        """
        assert len(images.size()) == 4
        assert padding in ['same', 'valid']
        batch_size, channel, height, width = images.size()
        
        if padding == 'same':
            images = same_padding(images, ksizes, strides, rates)
        elif padding == 'valid':
            pass
        else:
            raise NotImplementedError('Unsupported padding type: {}.\
                    Only "same" or "valid" are supported.'.format(padding))
    
        unfold = torch.nn.Unfold(kernel_size=ksizes,
                                 dilation=rates,
                                 padding=0,
                                 stride=strides)
        patches = unfold(images)
        return patches  # [N, C*k*k, L], L is the total number of such blocks
    
    def reduce_sum(x, axis=None, keepdim=False):
        if not axis:
            axis = range(len(x.shape))
        for i in sorted(axis, reverse=True):
            x = torch.sum(x, dim=i, keepdim=keepdim)
        return x
    
    def same_padding(images, ksizes, strides, rates):
        assert len(images.size()) == 4
        batch_size, channel, rows, cols = images.size()
        out_rows = (rows + strides[0] - 1) // strides[0]
        out_cols = (cols + strides[1] - 1) // strides[1]
        effective_k_row = (ksizes[0] - 1) * rates[0] + 1
        effective_k_col = (ksizes[1] - 1) * rates[1] + 1
        padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
        padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
        # Pad the input
        padding_top = int(padding_rows / 2.)
        padding_left = int(padding_cols / 2.)
        padding_bottom = padding_rows - padding_top
        padding_right = padding_cols - padding_left
        paddings = (padding_left, padding_right, padding_top, padding_bottom)
        images = torch.nn.ZeroPad2d(paddings)(images)
        return images
    
    def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
        return nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2),stride=stride, bias=bias)
    
    class BasicBlock(nn.Sequential):
        def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=nn.PReLU()):
    
            m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
            if bn:
                m.append(nn.BatchNorm2d(out_channels))
            if act is not None:
                m.append(act)
    
            super(BasicBlock, self).__init__(*m)
    
    class PyramidAttention(nn.Module):
        def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=default_conv):
            super(PyramidAttention, self).__init__()
            self.ksize = ksize
            self.stride = stride
            self.res_scale = res_scale
            self.softmax_scale = softmax_scale
            self.scale = [1-i/10 for i in range(level)]
            self.average = average
            escape_NaN = torch.FloatTensor([1e-4])
            self.register_buffer('escape_NaN', escape_NaN)
            self.conv_match_L_base = BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
            self.conv_match = BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
            self.conv_assembly = BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())
    
        def forward(self, input):
            res = input
            #theta
            match_base = self.conv_match_L_base(input)
            shape_base = list(res.size())
            input_groups = torch.split(match_base,1,dim=0)
            # patch size for matching 
            kernel = self.ksize
            # raw_w is for reconstruction
            raw_w = []
            # w is for matching
            w = []
            #build feature pyramid
            for i in range(len(self.scale)):    
                ref = input
                if self.scale[i]!=1:
                    ref  = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic',
                    align_corners=True,recompute_scale_factor=True)
                #feature transformation function f
                base = self.conv_assembly(ref)
                shape_input = base.shape
                #sampling
                raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                          strides=[self.stride,self.stride],
                                          rates=[1, 1],
                                          padding='same') # [N, C*k*k, L]
                raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
                raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
                raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
                raw_w.append(raw_w_i_groups)
    
                #feature transformation function g
                ref_i = self.conv_match(ref)
                shape_ref = ref_i.shape
                #sampling
                w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                      strides=[self.stride, self.stride],
                                      rates=[1, 1],
                                      padding='same')
                w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
                w_i = w_i.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
                w_i_groups = torch.split(w_i, 1, dim=0)
                w.append(w_i_groups)
    
            y = []
            for idx, xi in enumerate(input_groups):
                #group in a filter
                wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0)  # [L, C, k, k]
                #normalize
                max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                         axis=[1, 2, 3],
                                                         keepdim=True)),
                                   self.escape_NaN)
                wi_normed = wi/ max_wi
                #matching
                xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
                yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
                yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
                # softmax matching score
                yi = F.softmax(yi*self.softmax_scale, dim=1)
                
                if self.average == False:
                    yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
                
                # deconv for patch pasting
                raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
                yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.
                y.append(yi)
          
            y = torch.cat(y, dim=0)+res*self.res_scale  # back to the mini-batch
            return y
    

# 模型部分

  • PANet使用的是SRResNet的骨干

    class ResBlock(nn.Module):
        def __init__(self,inChannals):
            super().__init__()
            self.model = nn.Sequential(
                nn.Conv2d(inChannals,inChannals,kernel_size = 1,bias = False),
                nn.BatchNorm2d(inChannals),
                nn.ReLU(inplace = True),
                nn.Conv2d(inChannals,inChannals,kernel_size = 3,stride = 1,
                          padding = 1,bias = False,padding_mode = 'reflect'),
                nn.BatchNorm2d(inChannals)
            )
            
        def forward(self,input):
            return F.relu(input + self.model(input),inplace = True)
    
    class Sequential(nn.Sequential):
        def __init__(self,inChannals,blockNum = 8):
            seq = [ResBlock(inChannals) for _ in range(blockNum)]
            seq.insert(int(blockNum/2),PyramidAttention(channel=inChannals, level=4))
            super().__init__(*seq)
            
    class Model(nn.Module):
        def __init__(self,channals = 64,blockNum = 6):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3,channals,kernel_size = 7,padding = 3,stride = 1,
                          padding_mode = 'reflect',bias = False),
                nn.BatchNorm2d(channals),
                nn.ReLU(inplace = True),
                nn.Conv2d(channals,channals,kernel_size = 3,padding = 1,stride = 1,
                         padding_mode = 'reflect',bias = False),
                nn.BatchNorm2d(channals),
                nn.ReLU(inplace = True)
            )
            
            self.sequential = Sequential(channals,blockNum)
            
            self.upSample = nn.Sequential(
                nn.Conv2d(channals,channals * 4,kernel_size = 3,padding = 1,stride = 1,
                         padding_mode = 'reflect'),
                
                nn.PixelShuffle(2),
                nn.Conv2d(channals,channals,kernel_size = 3,padding = 1,stride = 1),
                nn.ReLU(inplace = True),
                nn.Conv2d(channals,3,kernel_size = 1,stride = 1),
                nn.Tanh()
            )
            
        def forward(self,input):
            features = self.features(input)
            output = self.sequential(features)
            output = features + output
            output = self.upSample(output)
    
            return output
    
  • 最后,通过简单的方式我们便可构建一个模型

    #如果电脑可以使用显卡,则自动使用显卡加速
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #创建网络模型 
    net = Model(channals = 64,blockNum = 24).to(device)
    

4.2.4 准备训练配件

  • 为了对模型进行训练和验证,我们需要以下部件:优化器Optimizer损失函数Criteria评估标注

# 优化器

  • 优化器我们使用了AdamW
    optimizer = optim.AdamW(net.parameters(),lr = 1e-4)
    

# 损失函数

  • 损失函数我们参考了原作者,使用了L1 Loss
    criteria = nn.L1Loss()
    

# 评估标准

  • 我们使用了SSIMPSRN两种标注,他们的代码如下所示:
## PSRN
  • 代码如下:
    def PSRN(img1, img2):
        mse = torch.mean((img1 - img2) ** 2)
        if mse < 1.0e-10:
            return 100
        return 10 * math.log10(255.0**2/mse)
    
## SSIM
  • 代码如下:
    def gaussian(window_size, sigma):
        gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
        return gauss/gauss.sum()
    
    def create_window(window_size, channel):
        _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
        return window
    
    def _ssim(img1, img2, window, window_size, channel, size_average = True):
        mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
        mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
    
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1*mu2
    
        sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
        sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
        sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
    
        C1 = 0.01**2
        C2 = 0.03**2
    
        ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    
        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean(1).mean(1).mean(1)
    
    
    def ssim(img1, img2, window_size = 11, size_average = True):
        (_, channel, _, _) = img1.size()
        window = create_window(window_size, channel)
        
        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)
        
        return _ssim(img1, img2, window, window_size, channel, size_average)
    
    • 好吧,这两个都是借鉴别人的(老懒狗了

4.2.5 构建训练框架

  • 训练框架如下所示:
if __name__ == '__main__':
    path = '你的数据集路径'
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = PreprocessDataset(path,size = 96)
    trainData = DataLoader(dataset,batch_size = 32,num_workers = 4,shuffle = True)

    net = Model(channals = 64,blockNum = 24).to(device)
    print(net)
    criteria = nn.L1Loss()
    optimizer = optim.AdamW(net.parameters(),lr = 1e-4)
    totalStep = len(trainData)

    # 构建可视化结果的保存路径
    if not os.path.exists('./img'):
        os.mkdir('./img')

    for epoch in range(startEpoch,10000):
        if epoch == 20 or epoch == 40:
            update_lr(optimizer, multiplier = .1)

        totalSSIM = 0.0
        totalPSRN = 0.0
        totalLoss = 0.0
        
        for step,(hr,lr) in enumerate(trainData,1):
            net.train(True)
            hr,lr = hr.to(device),lr.to(device)
            net.zero_grad()
            output = net(lr)
            loss = criteria(output,hr)
            loss.backward()
            optimizer.step()

            totalLoss += loss
            totalSSIM += ssim(output,hr)
            totalPSRN += PSRN(output,hr)
            
            print("[Epoch %d] Step: %d/%d Loss: %.4f|ssim: %.4f|psrn: %.4f" %
                (epoch,step,totalStep,totalLoss/step,totalSSIM/step,totalPSRN/step))

            if step >= 100: #对图像进行可视化
                net.train(False)
                outputs = net(lr)
                outputs = torch.cat([hr,outputs],dim = 0)
                save_image(outputs,'./Img/Result_epoch_%08d.jpg' % epoch,nrow = 8,normalize = True)      
  • 完整版代码支持重新打开代码自动恢复到上次训练的功能,只需要关注笔者即可获得:传送门

4.2.6 训练结果

  • 100次训练后结果:
    在这里插入图片描述

  • 10,000次训练后结果:
    在这里插入图片描述
    此时SSIM:0.6710PSRN:65.6384

  • 由于COCO数据集中的特征不唯一,因此需要更多的训练才能够达到更好的结果。

  • 完整版代码支持重新打开代码自动恢复到上次训练的功能,只需要关注笔者即可获得:传送门

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-26 10:10:08  更:2021-09-26 10:11:59 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年4日历 -2024/4/26 17:18:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码