I. About this example
This article uses PyTorch to build a GAN that generates handwritten digits.
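II. How GANs work
I won't go into detail here. CSDN (and Baidu) already has plenty of explanations of GANs (generative adversarial networks). One article I recommend for being easy to follow: 一文看懂「生成对抗网络 - GAN」基本原理+10种典型算法+13种应用 (roughly: "GANs in one article: basic principles, 10 typical algorithms, 13 applications").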
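III. GAN architecture
A GAN consists of a generator and a discriminator.
1) Generator architecture
The generator is nine CBR blocks connected in series (in the code, the ninth block ends with a Sigmoid instead of a ReLU). The exact parameters are in the code at the end of the article.
CBR = C + B + R
C = ConvTranspose. Note that this is a transposed convolution (often called "deconvolution"). The generator has to turn a simple vector (or a single value) into an image (a matrix), which is an "enlarging" (upsampling) process, so it needs transposed convolutions. Another recommended article on this: ConvTranspose2d原理,深度网络如何进行上采样? (roughly: "How ConvTranspose2d works: how do deep networks upsample?")
B = Batch Normalization
R = ReLU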
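To see how the transposed convolutions grow a 1x1 input into a full image, the side length after each layer can be traced by hand. The sketch below assumes dilation 1 and no output padding, in which case the output size of ConvTranspose2d is (in - 1) * stride - 2 * padding + kernel_size. The names conv_transpose_out, GEN_LAYERS and trace_generator are illustrative only; the (kernel_size, stride) pairs are copied from the generator code at the end of this article.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0):
    """Output side length of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size


# (kernel_size, stride) of the nine generator layers; padding is 0 everywhere.
GEN_LAYERS = [(3, 1), (3, 1), (3, 2), (2, 2), (3, 1), (3, 1), (3, 1), (3, 1), (3, 1)]


def trace_generator(size=1):
    """Return the side length before the first layer and after every layer."""
    sizes = [size]
    for kernel_size, stride in GEN_LAYERS:
        sizes.append(conv_transpose_out(sizes[-1], kernel_size, stride))
    return sizes


print(trace_generator())  # [1, 3, 5, 11, 22, 24, 26, 28, 30, 32]
```

The last value, 32, matches img_size in the code, so the generator output has the same size as the resized training images.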
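2) Discriminator architecture
The discriminator also has nine layers: CBL blocks connected in series (in the code, the ninth layer is a convolution followed directly by a Sigmoid). The exact parameters are in the code at the end of the article.
CBL = C + B + L
C = Conv. This is an ordinary convolution layer.
B = Batch Normalization
L = LeakyReLU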
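The discriminator does the opposite of the generator: each convolution shrinks the image until a single value is left. As a rough check, the sketch below traces the side length with the usual convolution formula floor((in + 2 * padding - kernel_size) / stride) + 1, assuming dilation 1. The names conv_out, DIS_LAYERS and trace_discriminator are illustrative only; the (kernel_size, stride) pairs are copied from the discriminator code at the end of this article.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output side length of a convolution (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride) of the nine discriminator layers; padding is 0 everywhere.
DIS_LAYERS = [(3, 1), (4, 1), (3, 1), (4, 1), (3, 1), (4, 1), (3, 2), (4, 1), (5, 1)]


def trace_discriminator(size=32):
    """Return the side length before the first layer and after every layer."""
    sizes = [size]
    for kernel_size, stride in DIS_LAYERS:
        sizes.append(conv_out(sizes[-1], kernel_size, stride))
    return sizes


print(trace_discriminator())  # [32, 30, 27, 25, 22, 20, 17, 8, 5, 1]
```

A 32x32 input ends up as a 1x1 output, which is why forward() can flatten the result with view(-1) into one score per image.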
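3) Training data
Download a copy of the MNIST dataset in image format and use the first 900 images for training. (If your machine can handle it, training on all of MNIST is of course better. Image-format copies of MNIST usually have to be paid for; leave your email address if you need one.)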
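The article does not say how the first 900 images were selected or how the training folder was laid out, so the following is only one possible way to prepare it. torchvision's ImageFolder expects the images inside at least one class subfolder of the root directory, so this sketch copies the first n files (sorted by file name, which is an assumption about what "first" means) into a single subfolder. The function name copy_first_n, the subfolder name "digits" and the source folder "mnist_png" are hypothetical.

```python
import shutil
from pathlib import Path


def copy_first_n(src_dir, dst_root, n=900, class_name="digits"):
    """Copy the first n files of src_dir (sorted by name) into dst_root/class_name/."""
    dst = Path(dst_root) / class_name
    dst.mkdir(parents=True, exist_ok=True)
    files = sorted(p for p in Path(src_dir).iterdir() if p.is_file())
    selected = files[:n]
    for path in selected:
        shutil.copy(path, dst / path.name)
    return len(selected)


# Hypothetical usage: "mnist_png" is an assumed folder holding the downloaded images.
# copy_first_n("mnist_png", "train_net_pic", n=900)
```

If the full dataset is already loaded through ImageFolder, wrapping it in torch.utils.data.Subset(dataset, range(900)) is another way to get the same restriction without copying files.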
IV. PyTorch code
The code is attached at the end of the article.
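V. Results
Looking at the images saved during the first 100 epochs, the network can already generate a fairly convincing "9", along with a rather blurry "7" and "8".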
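VI. Some thoughts
1) Why does the code train the discriminator only once for every five generator updates?
Intuitively, "creating" an image is a harder task than "recognizing" one, so the generator needs more training steps. The loss curves show this as well (blue is the generator loss, red is the discriminator loss).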
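The update schedule can be counted with a few lines of plain Python. This is only a sketch of the "it % 5" and "it % 1" conditions used in the training loop, not part of the original code; the function name count_updates and its parameters are illustrative.

```python
def count_updates(num_batches, dis_every=5, gen_every=1):
    """Count how often each network is updated within one epoch."""
    dis_updates = sum(1 for it in range(num_batches) if it % dis_every == 0)
    gen_updates = sum(1 for it in range(num_batches) if it % gen_every == 0)
    return dis_updates, gen_updates


# 900 training images with batch_size = 100 give 9 batches per epoch.
print(count_updates(9))  # (2, 9)
```

Because the batch counter restarts at 0 every epoch, the discriminator is updated at iterations 0 and 5, so with 9 batches the ratio per epoch is 9 generator updates to 2 discriminator updates rather than exactly 5 to 1.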
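2) Why are the generated digits still not very sharp?
In my view there are three causes, listed from most to least influential:
① The network design is not ideal. This model is just a plain stack of CBR blocks. Adding some pooling layers and fully connected layers might let the network be less "deep", and the results might be better too.
After finishing this article I found that there are indeed many examples of generating handwritten digits with a GAN. Most of them use fully connected layers, and they all work well. More complex images, however, will need a convolutional neural network (CNN), for example generating Dota2 hero portraits. Yes, that is the example I originally wanted to build, but several attempts failed. The biggest problem was probably too little training data: there are only 123 hero portraits in total, and they differ enormously (humans, elves, heroes with no eyes, heroes with no mouth, heroes with neither eyes nor mouth, heroes with one head, two heads, three heads...).
② The hyperparameters are not well chosen. The number of channels in each convolution layer, the kernel size, stride, padding, learning rate and so on all affect the result.
③ Too few training samples. See the Dota2 note above. MNIST itself is certainly large enough, though, so this factor probably matters little here.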
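import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

img_size = 32        # side length of the training images and of the generated images
batch_size = 100     # number of images per batch
max_epoch = 200      # number of training epochs
init_channel = 100   # number of channels of the 1x1 noise tensor fed to the generator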


class Gen_net(nn.Module):
    """Generator: nine transposed-convolution blocks, 1x1 noise -> 3x32x32 image."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # Block 1: 1x1 -> 3x3
            nn.ConvTranspose2d(in_channels=init_channel, out_channels=768, kernel_size=3, stride=1,
                               padding=0, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
            # Block 2: 3x3 -> 5x5
            nn.ConvTranspose2d(in_channels=768, out_channels=384, kernel_size=3, stride=1,
                               padding=0, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            # Block 3: 5x5 -> 11x11
            nn.ConvTranspose2d(in_channels=384, out_channels=192, kernel_size=3, stride=2,
                               padding=0, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            # Block 4: 11x11 -> 22x22
            nn.ConvTranspose2d(in_channels=192, out_channels=96, kernel_size=2, stride=2,
                               padding=0, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            # Block 5: 22x22 -> 24x24
            nn.ConvTranspose2d(in_channels=96, out_channels=48, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(),
            # Block 6: 24x24 -> 26x26
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            # Block 7: 26x26 -> 28x28
            nn.ConvTranspose2d(in_channels=24, out_channels=12, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            # Block 8: 28x28 -> 30x30
            nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            # Block 9: 30x30 -> 32x32, Sigmoid maps the pixel values into (0, 1)
            nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)


class Dis_net(nn.Module):
    """Discriminator: nine convolution layers, 3x32x32 image -> one score in (0, 1)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # Block 1: 32x32 -> 30x30
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(6),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 2: 30x30 -> 27x27
            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(12),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 3: 27x27 -> 25x25
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(24),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 4: 25x25 -> 22x22
            nn.Conv2d(in_channels=24, out_channels=48, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 5: 22x22 -> 20x20
            nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 6: 20x20 -> 17x17
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 7: 17x17 -> 8x8
            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 8: 8x8 -> 5x5
            nn.Conv2d(in_channels=384, out_channels=768, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            # Layer 9: 5x5 -> 1x1, Sigmoid turns the score into a probability
            nn.Conv2d(in_channels=768, out_channels=1, kernel_size=5, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Flatten (N, 1, 1, 1) to (N,) so the output lines up with the label vectors.
        return self.net(x).view(-1)


gen_net = Gen_net()
dis_net = Dis_net()

# Resize every image to img_size x img_size and convert it to a tensor with values in [0, 1].
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize([img_size, img_size]),
    torchvision.transforms.ToTensor()
])
# ImageFolder loads images as 3-channel RGB and expects them inside class subfolders of root.
datasets = torchvision.datasets.ImageFolder(root='train_net_pic', transform=transforms)
dataloader = DataLoader(datasets, batch_size=batch_size, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_net.to(device)
dis_net.to(device)

loss = nn.BCELoss().to(device)
opt_gen = torch.optim.Adadelta(gen_net.parameters(), lr=0.001)
opt_dis = torch.optim.Adadelta(dis_net.parameters(), lr=0.001)

# Target labels: 1 for "real", 0 for "fake".
true_label = torch.ones(batch_size).to(device)
fake_label = torch.zeros(batch_size).to(device)

# Fixed noise batches: one for the discriminator step, one for the generator step,
# and one for the sample images saved after every epoch.
init_tensor = torch.randn(batch_size, init_channel, 1, 1).to(device)
gen_init_tensor = torch.randn(batch_size, init_channel, 1, 1).to(device)
test_init_tensor = torch.randn(batch_size, init_channel, 1, 1).to(device)
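
# Note: the output directory gen_pic_deep must exist before training starts.
for epoch in range(max_epoch):
    for it, (img, _) in enumerate(tqdm(dataloader)):
        real_pic = img.to(device)

        # Train the discriminator once every 5 iterations.
        if it % 5 == 0:
            opt_dis.zero_grad()
            real_output = dis_net(real_pic)
            # Slice the labels so that a smaller final batch still matches in size.
            dis_real_loss = loss(real_output, true_label[:real_output.size(0)])
            # detach() stops the discriminator loss from propagating into the generator.
            fake_pic = gen_net(init_tensor).detach()
            fake_output = dis_net(fake_pic)
            dis_fake_loss = loss(fake_output, fake_label)
            dis_loss = dis_fake_loss + dis_real_loss
            dis_loss.backward()
            opt_dis.step()
            # item() returns a Python float and also works when the loss lives on the GPU.
            plt.scatter(epoch, dis_loss.item(), c='r')

        # Train the generator on every iteration.
        if it % 1 == 0:
            opt_gen.zero_grad()
            gen_pic = gen_net(gen_init_tensor)
            gen_output = dis_net(gen_pic)
            # The generator wants the discriminator to label its images as real.
            gen_loss = loss(gen_output, true_label)
            gen_loss.backward()
            opt_gen.step()
            plt.scatter(epoch, gen_loss.item(), c='b')

    # After each epoch, save 8 images generated from the fixed test noise.
    with torch.no_grad():
        sample = gen_net(test_init_tensor)
    torchvision.utils.save_image(sample[:8], "%s/GAN_MNISTER_deep_3rd_%s.png" % ('gen_pic_deep', epoch),
                                 normalize=True)
    torch.save(sample[0], 'gen_pic_deep/tensor.txt')
    print('save_%s' % epoch)

# Show the loss scatter plot (red: discriminator, blue: generator).
plt.show()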