IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GFPGAN源码分析—第六篇 -> 正文阅读

[人工智能]GFPGAN源码分析—第六篇

2021SC@SDUSC

源码:archs\gfpganv1_clean_arch.py

本篇主要分析gfpganv1_clean_arch.py下的

class GFPGANv1Clean(nn.Module)类_init_()方法

目录

class GFPGANv1Clean(nn.Module)

init()

(1)channels的设置

(2)调用torch.nn.Conv2d()创建了一层卷积神经网络

(3)下采样(downsample)

(4)上采样(upsample)

(5)全连接层

(6)创建self.stylegan_decoder

(7)如果decoder_load_path不为空则读取

(8)for SFT(SFT layer)


class GFPGANv1Clean(nn.Module)

????????继承自nn.Module类,使得我们可以使用很多现成的类,比如本类中使用的Conv2d以及RelU激活函数等等。

init()

参数:

self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False

在class GFPGANer()-init()中被调用时:

self.gfpgan = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=channel_multiplier,
    decoder_load_path=None,
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)

(1)channels的设置

实际调用的时候narrow=1,

channels保存了经过convolution层后的输出的通道数

unet_narrow = narrow * 0.5

channels = {
    '4': int(512 * unet_narrow),
    '8': int(512 * unet_narrow),
    '16': int(512 * unet_narrow),
    '32': int(512 * unet_narrow),
    '64': int(256 * channel_multiplier * unet_narrow),
    '128': int(128 * channel_multiplier * unet_narrow),
    '256': int(64 * channel_multiplier * unet_narrow),
    '512': int(32 * channel_multiplier * unet_narrow),
    '1024': int(16 * channel_multiplier * unet_narrow)
}

(2)调用torch.nn.Conv2d()搭建卷积神经网络

#out_size=512,so log_size=9
self.log_size = int(math.log(out_size, 2))
#first_out_size = 512
first_out_size = 2 ** (int(math.log(out_size, 2)))
#channels['512']=32*2*0.5=32
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

在这里介绍一下nn.Conv2d()的几个参数

in_channels: int,#输入的通道数目【必选】
out_channels: int,# 输出的通道数目【必选】
kernel_size: _size_2_t,#卷积核的大小,类型为int(方形边长) 或者元组(长和宽)【必选】
stride: _size_2_t = 1,#步长
padding: Union[str, _size_2_t] = 0,#边界增益,可以控制输出结果的尺寸
dilation: _size_2_t = 1,#控制卷积核之间的间距
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',  # TODO: refine this type
device=None,
dtype=None

那么可以得知

self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)


#实际上是传入通道为3(RGB)的输入,使用边长为1的卷积核,最后获得通道为32的输出
#由于卷积核边长为1,我们输入与输入的图片大小仍然保持一致,但增加了通道数

(3)下采样(downsample)

可以看到实际上是调用ResBlock做了下采样

# 输入图片的通道数(实际为32)
in_channels = channels[f'{first_out_size}']
 #创建ModuleList容器
self.conv_body_down = nn.ModuleList()
# i从self.log_size(9)->3      :7次循环
for i in range(self.log_size, 2, -1):
    out_channels = channels[f'{2 ** (i - 1)}']
    #调用ResBlock残差网络做下采样,并将该module添加到设置的ModuleList
    self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
    #这一层的输出管道数作为下一层输入的管道数
    in_channels = out_channels

介绍一下nn.ModuleList()

nn.ModuleList,它是一个储存不同module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。
#注意nn.ModuleList则没有实现内部forward函数,所以需要手动实现

最后一层卷积层的搭建:

#最终输出通道数为channels['4']=256,使用边长为3的卷积核,步长为1,padding为1,保证维度不变
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

(4)上采样(upsample)

#输入通道数为channels['4']=256,即下采样的输出的通道数
        in_channels = channels['4']
        #创建ModuleList容器
        self.conv_body_up = nn.ModuleList()
        # i从3->self.log_size(9)     :7次循环
        for i in range(3, self.log_size + 1):
            # 定义输出的通道数
            out_channels = channels[f'{2 ** i}']
            # 调用带有上采样ResBlock残差网络,并将该module添加到设置的ModuleList
            self.conv_body_up.append(ResBlock(in_channels, out_channels, 
                                              mode='up'))
            #这一层的输出管道数作为下一层输入的管道数
            in_channels = out_channels

(5)全连接层

根据传入的参数different_w,选择每个输出样本的大小,并搭建相应的全连接层。

if different_w:
    #16*512=8192
    linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
    print(linear_out_channel)
else:
    #512
    linear_out_channel = num_style_feat
#全连接层size of each input sample:4096,size of each output sample:8192
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

(6)创建self.stylegan_decoder

self.stylegan_decoder = StyleGAN2GeneratorCSFT(
    out_size=out_size,
    num_style_feat=num_style_feat,
    num_mlp=num_mlp,
    channel_multiplier=channel_multiplier,
    narrow=narrow,
    sft_half=sft_half)

(7)如果decoder_load_path不为空则读取

if decoder_load_path:
    self.stylegan_decoder.load_state_dict(
        torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
if fix_decoder:
    for name, param in self.stylegan_decoder.named_parameters():
        param.requires_grad = False

(8)for SFT(SFT layer)

#ModuleList
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
  # i从3->self.log_size(9)     :7次循环
for i in range(3, self.log_size + 1):
    # 定义输出的通道数
    out_channels = channels[f'{2 ** i}']
     #输出通道数是否减半
    if sft_half:
        sft_out_channels = out_channels
    else:
        sft_out_channels = out_channels * 2
         #使用nn.Sequential搭建网络,并添加到ModuleList
    self.condition_scale.append(
        nn.Sequential(
             #卷积核边长为3,步长为1,输出与输出保持相同维度
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
    self.condition_shift.append(
        nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

nn.Sequential是一个有序的容器,其中传入的是构造器类(各种用来处理input的类),最终input会被Sequential中的构造器依次执行。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-07 12:01:51  更:2021-12-07 12:02:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 0:33:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码