IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于Pytorch用GAN生成手写数字实例(附代码) -> 正文阅读

[人工智能]基于Pytorch用GAN生成手写数字实例(附代码)

一、本文的实例说明

本文旨在用Pytorch构建一个GAN网络,这个GAN网络可以生成手写数字。

二、GAN原理说明

这快不做赘述,CSDN上(及baidu上)关于GAN(生成对抗网络)的说明实在太多,这里推荐一篇文章,写的通俗易懂:一文看懂「生成对抗网络 - GAN」基本原理+10种典型算法+13种应用

三、GAN网络架构说明

GAN由生成器(generator)和判别器(discriminator)组成。

1)生成器架构

由9个CBR模块串联形成,结构及参数如下:
在这里插入图片描述

CBR=C+B+R
C=ConvTranspose *注意!这里是逆卷积,因为生成器要把一个简单的向量(或者数值)生成一个图片(矩阵),这是一个“扩大”(上采样)的过程,所以要用逆卷积。这里再推荐一篇文章:ConvTranspose2d原理,深度网络如何进行上采样?
B=Batch Normalization;
R=ReLU;

2)判别器架构

也有9层,由9个CBL模块串联组成,结构及参数如下:
在这里插入图片描述

CBL=C+B+L
C=Conv *这里就是卷积层;
B=Batch Normalization;
L=LeakyReLu;

3)训练数据

从网上下载图片格式的MNIST数据集,然后取前900个训练(当然,计算机性能允许的话MNIST数据全部拿来训练更好。图片格式的MNIST数据集一般要付费,如果需要请留邮箱)

四、Pytorch代码

附在最后

五、生成结果

取训练过程的前100个epcoh的图片,可以看出已经基本能生成一个比较像样的“9”,还有比较模糊的“7”和“8”。
在这里插入图片描述

六、一些理解

1)为什么在代码中生成器每训练5次判别器才训练一次?

直观理解,相比于“识别”图像,“创造”图像是一个更加复杂的任务,所以训练的次数要更多。从loss上也可以看出。(蓝色为生成器loss,红色为判别器loss)
在这里插入图片描述

2)为什么最终生成的数字还是不太清晰?

个人理解,按影响从大到小有以下3个方面:
①网络模型不太合理:本次只采用了CBR模块的简单串联,如果加入些池化层,全连接层,网络可能不用这么“深”,而且效果可能更好;

写完这篇文章之后,发现确实有不少用GAN生成手写数字的实例,基本都是用全连接层做的,而且效果都不错。但是对于复杂的图像肯定是要用到CNN卷积神经网络的,比如生成Dota2英雄头像:
在这里插入图片描述
没错,最开始我是想做这个实例的。但是无奈做了几次都不成功,最大的问题可能是因为训练数据太少了,英雄头像总共就123个,而且差异非常大(有人类,精灵,有没有眼睛的,没有嘴的,既没有眼睛也没有嘴的,有一个头的,两个头的,三个头的。。。。)

②设置参数不合理:卷积层的Channel数量,Kernel size,stride,padding,learning rate等等这些都有影响;
③训练数据样本太少:参考上面Dota2头像的说明,但是MNIST数据集确实够大了,这个原因应该影响不大。


import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm  #tqdm含义tqdm derives from the Arabic word taqaddum (?????) which can mean “progress,” and is an abbreviation for “I love you so much” in Spanish (te quiero demasiado).
import matplotlib.pyplot as plt


img_size = 32
batch_size = 100
max_epoch = 200 #迭代次数,这个参数可以自己设计
init_channel = 100 #初始通道数,这个参数可以自己设计


# 数据集有900张图片(从MNIST选择900张图)
class Gen_net(nn.Module):

    def __init__(self):
        super(Gen_net, self).__init__()
        self.net = nn.Sequential(

            # 第一层
            nn.ConvTranspose2d(in_channels=init_channel, out_channels=768, kernel_size=3, stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),


            # 第二层
            nn.ConvTranspose2d(in_channels=768, out_channels=384, kernel_size=3, stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),


            # 第三层
            nn.ConvTranspose2d(in_channels=384, out_channels=192, kernel_size=3, stride=2,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),


            # 第四层
            nn.ConvTranspose2d(in_channels=192, out_channels=96, kernel_size=2, stride=2, padding=0,
                               bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),


            # 第五层
            nn.ConvTranspose2d(in_channels=96, out_channels=48, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(),


            #第六层
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),


            #第七层
            nn.ConvTranspose2d(in_channels=24, out_channels=12, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(12),
            nn.ReLU(),


            #第八层
            nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(),


            #第九层
            nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )

    def forward(self,x):
        return self.net(x)


class Dis_net(nn.Module):
    def __init__(self):
        super(Dis_net,self).__init__()
        self.net = nn.Sequential(

            #第一层
            nn.Conv2d(in_channels=3, out_channels= 6, kernel_size= 3, stride= 1, padding= 0, bias=False),
            nn.BatchNorm2d(6),
            nn.LeakyReLU(0.2, inplace= True),


            #第二层
            nn.Conv2d(in_channels= 6, out_channels=12, kernel_size= 4, stride= 1, padding= 0, bias=False),
            nn.BatchNorm2d(12),
            nn.LeakyReLU(0.2, True),


            #第三层
            nn.Conv2d(in_channels= 12, out_channels= 24, kernel_size= 3, stride= 1, padding= 0,bias=False),
            nn.BatchNorm2d(24),
            nn.LeakyReLU(0.2, True),


            #第四层
            nn.Conv2d(in_channels= 24, out_channels= 48, kernel_size= 4, stride=1, padding= 0, bias=False),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.2, True),


            #第五层
            nn.Conv2d(in_channels= 48, out_channels= 96, kernel_size= 3, stride= 1, padding= 0, bias=True),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, True),


            #第六层
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.2, True),


            #第七层
            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(384),
            nn.LeakyReLU(0.2, True),


            #第八层
            nn.Conv2d(in_channels=384, out_channels=768, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(768),
            nn.LeakyReLU(0.2, True),


            #第九层
            nn.Conv2d(in_channels=768, out_channels=1, kernel_size=5, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self,x):
        return self.net(x).view(-1)


gen_net = Gen_net()
dis_net = Dis_net()

#图像处理
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize([img_size, img_size]),
    torchvision.transforms.ToTensor()
])

datasets = torchvision.datasets.ImageFolder(root='train_net_pic', transform=transforms)

dataloader = DataLoader(datasets, batch_size=batch_size, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen_net.to(device)
dis_net.to(device)

#定义loss函数为二分交叉熵(discriminator的“评分”在0~1之间)
loss = nn.BCELoss().to(device)

#定义loss函数优化方法为Adadelta
opt_gen = torch.optim.Adadelta(gen_net.parameters(), lr= 0.001)
opt_dis = torch.optim.Adadelta(dis_net.parameters(), lr= 0.001)

true_label = torch.ones(batch_size).to(device)
fake_label = torch.zeros(batch_size).to(device)  #真实照片评分为1,假照片评分为0

init_tensor = torch.randn(batch_size,init_channel,1,1).to(device) #初始矩阵,给gen_net生成图像。(这个矩阵可以自己定义大小,并不限于1×1,但是对应的网络架构也要改)
gen_init_tensor = torch.randn(batch_size,init_channel,1,1).to(device)
test_init_tensor = torch.randn(batch_size,init_channel,1,1).to(device)

for epoch in range(max_epoch):   #迭代max_epoch次

    for it, (img, _) in tqdm(enumerate(dataloader)):  #遍历所有真实的图像( “_”是enumerate的用法,可以用变量代替)
        real_pic = img.to(device)
#----------------------训练dis_net----------------------
        if it%5 == 0 :
            opt_dis.zero_grad()  # zero_grad(), step()的用法参考前一篇文章
            real_output = dis_net(real_pic)  # 得到真实图片的discriminator网络输出一个0~1的“评分”,期望为1
            dis_real_loss = loss(real_output, true_label)  # 真实图片经过discriminator网络获得的输出和1矩阵的二分交叉熵作为loss

            fake_pic = gen_net(init_tensor.detach()).detach()
            fake_output = dis_net(fake_pic)  # 得到假图片(生成网络生成的图片)的discriminator网络输出一个0~1的“评分”,期望为0
            dis_fake_loss = loss(fake_output, fake_label)
            dis_loss = (dis_fake_loss + dis_real_loss)  # 判别网络的总loss

            dis_loss.backward()
            opt_dis.step()

            dis_loss_numpy = dis_loss.detach().numpy()
            plt.scatter(epoch, dis_loss_numpy, c='r')
#--------------------训练gen_net-------------------------
        if it%1 == 0 :
            opt_gen.zero_grad()
            gen_pic = gen_net(gen_init_tensor)
            gen_output = dis_net(gen_pic)
            gen_loss = loss(gen_output, true_label)  #期望生成的图片评分为1(生成网络期望“骗过”评价网络)

            gen_loss.backward()
            opt_gen.step()

            gen_loss_numpy = gen_loss.detach().numpy()
            plt.scatter(epoch, gen_loss_numpy, c='b')



    img = gen_net(test_init_tensor)
    torchvision.utils.save_image(img.data[:8], "%s/GAN_MNISTER_deep_3rd_%s.png" % ('gen_pic_deep', epoch),  #取每个batch的前8张图
                                 normalize=True)
    torch.save(img.data[0] ,'gen_pic_deep/tensor.txt')
    print('save_%s'%epoch)

plt.show()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-08 20:42:06  更:2022-10-08 20:45:14 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 3:12:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计