前言.
- 关于
G
A
N
\rm GAN
GAN 的基础知识以及
D
C
G
A
N
\rm DCGAN
DCGAN 的特点,可以阅读《生成对抗网络》,这里将深度卷积生成对抗网络的特点摘录如下:
- 取消所有
p
o
o
l
i
n
g
\rm pooling
pooling 层,
G
e
n
e
r
a
t
o
r
\rm Generator
Generator 网络中使用转置卷积进行上采样,
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm Discriminator
Discriminator 网络中用加入
s
t
r
i
d
e
\rm stride
stride 的卷积代替
p
o
o
l
i
n
g
.
\rm pooling.
pooling.
- 在
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm Discriminator
Discriminator 和
G
e
n
e
r
a
t
o
r
\rm Generator
Generator 中均使用
B
a
t
c
h
?
N
o
r
m
a
l
i
z
a
t
i
o
n
.
\rm Batch~Normalization.
Batch?Normalization.
- 去掉全连接层,即线性层,使网络变为全卷积网络。
-
G
e
n
e
r
a
t
o
r
\rm Generator
Generator 网络中使用
R
e
L
U
\rm ReLU
ReLU 作为激活函数,最后一层使用
T
a
n
h
.
\rm Tanh.
Tanh.
-
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm Discriminator
Discriminator 网络中使用
L
e
a
k
y
R
e
L
U
\rm LeakyReLU
LeakyReLU 作为激活函数。
判别网络.
-
D
C
G
A
N
\rm DCGAN
DCGAN 中判别网络
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm Discriminator
Discriminator 与普通的卷积网络区别只在于使用带步长的卷积层来替换池化层,并且激活函数使用了
L
e
a
k
y
R
e
L
U
\rm LeakyReLU
LeakyReLU,该激活函数针对
R
e
L
U
\rm ReLU
ReLU 在输入为负时反向传播梯度为
0
0
0 进行修改,会对判别网络的效率有所改善。
-
P
y
T
o
r
c
h
\rm PyTorch
PyTorch 实现的
D
C
G
A
N
.
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm DCGAN.Discriminator
DCGAN.Discriminator 代码如下所示:
class _DNet(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(3,64,4,2,1,bias = False),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64,64*2,4,2,1,bias = False),
nn.BatchNorm2d(64*2),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*2,64*4,4,2,1,bias = False),
nn.BatchNorm2d(64*4),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*4,64*8,4,2,1,bias = False),
nn.BatchNorm2d(64*8),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*8,1,4,1,0,bias = False),
nn.Sigmoid())
def forward(self,inputs):
output = self.network(inputs)
output = output.view(-1,1).squeeze(1)
return output
- 判别网络
D
i
s
c
r
i
m
i
n
a
t
o
r
\rm Discriminator
Discriminator 对于批数据中的每个样本,其输出值都是一个标量,用于表示该样本是真实图片
t
r
u
e
\rm true
true 抑或是生成图片
f
a
k
e
.
\rm fake.
fake. 在本例中以
1
1
1 表示真实图片,
0
0
0 表示生成图片。
- 我们的输入图片大小为
3
×
64
×
64
3\times64\times64
3×64×64,注意判别网络中的卷积核参数,其卷积核大小为
4
4
4,步长为
2
2
2,填充为
1
1
1,这是一种常见的卷积核,其输出图像的大小计算式如下:
H
i
n
+
2
?
4
2
+
1
=
H
i
n
2
(1)
\frac{H_{in}+2-4}{2}+1=\frac{H_{in}}{2}\tag{1}
2Hin?+2?4?+1=2Hin??(1)
- 并且由于网络中添加了
B
a
t
c
h
?
N
o
r
m
a
l
i
z
a
t
i
o
n
\rm Batch~Normalization
Batch?Normalization 层,因此舍弃了卷积层偏置项
b
i
a
s
\rm bias
bias,具体的原因可以参考《卷积层操作后是否需要附加偏置项 bias.》
生成网络.
- 需要明确,生成网络
G
e
n
e
r
a
t
o
r
\rm Generator
Generator 接受一个随机向量来合成一张图片,在本例中使用的随机向量维度是
100
×
1
×
1.
100\times1\times1.
100×1×1.
-
P
y
T
o
r
c
h
\rm PyTorch
PyTorch 实现的生成网络
D
C
G
A
N
.
G
e
n
e
r
a
t
o
r
\rm DCGAN.Generator
DCGAN.Generator 代码如下所示:
class _GNet(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.ConvTranspose2d(100,64*8,4,1,0,bias = False),
nn.BatchNorm2d(64*8),
nn.ReLU(True),
nn.ConvTranspose2d(64*8,64*4,4,2,1,bias = False),
nn.BatchNorm2d(64*4),
nn.ReLU(True),
nn.ConvTranspose2d(64*4,64*2,4,2,1,bias = False),
nn.BatchNorm2d(64*2),
nn.ReLU(True),
nn.ConvTranspose2d(64*2,64,4,2,1,bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64,3,4,2,1,bias = False),
nn.Tanh())
def forward(self,inputs):
output = self.network(inputs)
return output
- 注意生成网络中使用了转置卷积
n
n
.
C
o
n
v
T
r
a
n
s
p
o
s
e
2
d
(
)
\rm nn.ConvTranspose2d()
nn.ConvTranspose2d() 来进行上采样,最终网络的输出是大小为
3
×
64
×
64
3\times64\times64
3×64×64 的张量,代表一张合成图片。
训练过程.
-
D
C
G
A
N
\rm DCGAN
DCGAN 在完成了网络参数初始化后,首先从判别网络开始训练,分别用真实图片集和合成图片集对其进行训练,将误差加合起来进行反向传播。由于具体的分类问题实质是二分类问题,因此损失函数采用二分类交叉熵函数
n
n
.
B
C
E
L
o
s
s
\rm nn.BCELoss
nn.BCELoss.
- 我们将真实图片记为
x
x
x,输入判别网络后对应的结果为
D
(
x
)
D(x)
D(x);输入生成网络的随机向量记为
z
z
z,生成网络的输出结果为
G
(
z
)
G(z)
G(z),因此判别网络的判别结果为
D
[
G
(
z
)
]
.
D[G(z)].
D[G(z)].
-
B
C
E
L
o
s
s
\rm BCELoss
BCELoss 的损失值计算式如下:
l
(
y
^
,
y
)
=
?
[
y
?
log
?
y
^
+
(
1
?
y
)
?
log
?
(
1
?
y
^
)
]
(2)
l(\hat y,y)=-[y\cdot\log\hat y+(1-y)\cdot\log(1-\hat y)]\tag{2}
l(y^?,y)=?[y?logy^?+(1?y)?log(1?y^?)](2)因此对判别网络而言,真实图片集的损失为:
l
1
=
?
log
?
[
D
(
x
)
]
(3.1)
l_1=-\log\big[D(x)\big]\tag{3.1}
l1?=?log[D(x)](3.1)合成图片集的损失为:
l
2
=
?
log
?
[
1
?
D
[
G
(
z
)
]
]
(3.2)
l_2=-\log\big[1-D[G(z)]\big]\tag{3.2}
l2?=?log[1?D[G(z)]](3.2)将
l
1
,
l
2
l_1,l_2
l1?,l2? 加合后就是判别网络最终需要优化的损失函数:
l
d
=
?
log
?
[
D
(
x
)
]
?
log
?
[
1
?
D
[
G
(
z
)
]
]
(4)
l_d=-\log\big[D(x)\big]-\log\big[1-D[G(z)]\big]\tag{4}
ld?=?log[D(x)]?log[1?D[G(z)]](4)
- 而后训练生成网络,训练方式基于判别网络的反馈。即对于随机向量
x
x
x,将生成图片
G
(
z
)
G(z)
G(z) 传递给判别网络,得到判别结果
D
[
G
(
z
)
]
D[G(z)]
D[G(z)],籍此来进行梯度反向传播训练生成网络。这里需要注意,对生成网络而言,合成图片的目的是欺骗判别网络,因此合成图片在生成网络中的标签值应为
t
r
u
e
.
\rm true.
true. 同样采样二分类交叉熵函数计算损失值,得到生成网络的损失为:
l
g
=
?
log
?
[
D
[
G
(
z
)
]
]
(5)
l_g=-\log\big[D[G(z)]\big]\tag{5}
lg?=?log[D[G(z)]](5)
- 注意在该过程中,判定合成图片的真假会有判别网络参与,但这里我们必须固定判别网络的参数,只对生成网络进行训练。所以实际上训练生成网络的过程实际在训练生成-判别串接网络,只不过判别网络的参数不进行更新。
- 关于训练过程,还可以参看《大白话讲解
G
A
N
\rm GAN
GAN 的训练机制》互相验证。
- 下面分步骤给出
D
C
G
A
N
\rm DCGAN
DCGAN 训练阶段的代码,分别对应上面所说的训练过程。
- 首先是使用真实图片集对判别网络进行训练:
DNet.zero_grad()
r_sample = data[0]
batch_size = r_sample.size()[0]
if is_cuda:
r_sample = r_sample.cuda()
inputs.resize_as_(r_sample).copy_(r_sample)
'''1 for real'''
labels.resize_(batch_size).fill_(1)
outputs = DNet(inputs)
r_DNetloss = loss(outputs,labels)
r_DNetloss.backward()
Dx = outputs.data.mean()
- 而后是使用合成图片集对判别网络进行训练,注意这里在训练判别网络时并不希望同时训练生成网络,因此需要切断梯度流。
noise = torch.FloatTensor(batch_size,100,1,1).normal_(0,1)
if is_cuda:
noise = noise.cuda()
f_sample = GNet(noise)
'''0 for fake'''
labels = labels.fill_(0)
outputs = DNet(f_sample.detach())
f_DNetloss = loss(outputs,labels)
f_DNetloss.backward()
DGz1 = outputs.data.mean()
DNetLoss = r_DNetloss + f_DNetloss
opt_d.step()
- 我们注意到在
r
_
D
N
e
t
l
o
s
s
\rm r\_DNetloss
r_DNetloss 和
f
_
D
N
e
t
l
o
s
s
\rm f\_DNetloss
f_DNetloss 都反向传播后才调用了
o
p
t
_
d
.
s
t
e
p
(
)
.
\rm opt\_d.step().
opt_d.step(). 并且
f
_
s
a
m
p
l
e
\rm f\_sample
f_sample 在传入判别网络时进行了
d
e
t
a
c
h
(
)
\rm detach()
detach() 操作。
- 最后是基于判别网络的反馈来训练生成网络:
GNet.zero_grad()
'''Fake sample is \real\ to generator net '''
labels = labels.fill_(1)
outputs = DNet(f_sample)
GNetLoss = loss(outputs,labels)
GNetLoss.backward()
DGz2 = outputs.data.mean()
opt_g.step()
DCGAN完整代码.
- 使用
C
I
F
A
R
10
\rm CIFAR10
CIFAR10 数据集进行训练,生成器最终生成的图片会和其中的图片十分类似。下面分别是真实图片和合成图片的展示:
- 合成图片中箭头所指已经十分像一只真正的🐱.
- 下面是本例中
D
C
G
A
N
\rm DCGAN
DCGAN 的完整代码,推荐使用
S
p
y
d
e
r
\rm Spyder
Spyder 或
N
o
t
e
b
o
o
k
\rm Notebook
Notebook 进行体验,每个
I
n
[
?
]
\rm In[~]
In[?] 是一个
c
e
l
l
.
\rm cell.
cell.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
from PIL import Image
if torch.cuda.is_available():
print('GPU works.')
is_cuda = True
img_size = 64
mean = [0.5,0.5,0.5]
std = [0.5,0.5,0.5]
trans = transforms.Compose([transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean, std)])
ds = datasets.CIFAR10(root = 'data',download = True,
transform = trans)
loader = torch.utils.data.DataLoader(ds,batch_size = 64,
shuffle = True)
class _DNet(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(3,64,4,2,1,bias = False),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64,64*2,4,2,1,bias = False),
nn.BatchNorm2d(64*2),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*2,64*4,4,2,1,bias = False),
nn.BatchNorm2d(64*4),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*4,64*8,4,2,1,bias = False),
nn.BatchNorm2d(64*8),
nn.LeakyReLU(0.2,True),
nn.Conv2d(64*8,1,4,1,0,bias = False),
nn.Sigmoid())
def forward(self,inputs):
output = self.network(inputs)
output = output.view(-1,1).squeeze(1)
return output
class _GNet(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.ConvTranspose2d(100,64*8,4,1,0,bias = False),
nn.BatchNorm2d(64*8),
nn.ReLU(True),
nn.ConvTranspose2d(64*8,64*4,4,2,1,bias = False),
nn.BatchNorm2d(64*4),
nn.ReLU(True),
nn.ConvTranspose2d(64*4,64*2,4,2,1,bias = False),
nn.BatchNorm2d(64*2),
nn.ReLU(True),
nn.ConvTranspose2d(64*2,64,4,2,1,bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64,3,4,2,1,bias = False),
nn.Tanh())
def forward(self,inputs):
output = self.network(inputs)
return output
def weight(model):
cls_name = model.__class__.__name__
if cls_name.find('Conv') != -1:
model.weight.data.normal_(0.0,0.02)
elif cls_name.find('BatchNorm') != -1:
model.weight.data.normal_(0.0,0.02)
model.bias.data.fill_(0)
DNet = _DNet()
GNet = _GNet()
DNet.apply(weight)
GNet.apply(weight)
print(DNet)
print(GNet)
loss = nn.BCELoss()
lr = 2e-4
betas = (0.5,0.999)
opt_d = optim.Adam(DNet.parameters(),lr,betas)
opt_g = optim.Adam(GNet.parameters(),lr,betas)
inputs = torch.FloatTensor(64,3,img_size,img_size)
labels = torch.FloatTensor(64)
noise = torch.FloatTensor(64,100,1,1)
fix_noise = torch.FloatTensor(64,100,1,1).normal_(0,1)
outpath = './output'
if is_cuda:
DNet = DNet.cuda()
GNet = GNet.cuda()
loss = loss.cuda()
inputs,labels = inputs.cuda(),labels.cuda()
noise,fix_noise = noise.cuda(),fix_noise.cuda()
epochs = 10
DNetLosses = []
GNetLosses = []
for epoch in range(epochs):
for batch,data in enumerate(loader,start = 0):
'''Train DNet'''
DNet.zero_grad()
r_sample = data[0]
batch_size = r_sample.size()[0]
if is_cuda:
r_sample = r_sample.cuda()
inputs.resize_as_(r_sample).copy_(r_sample)
'''1 for real'''
labels.resize_(batch_size).fill_(1)
outputs = DNet(inputs)
r_DNetloss = loss(outputs,labels)
r_DNetloss.backward()
Dx = outputs.data.mean()
noise = torch.FloatTensor(batch_size,100,1,1).normal_(0,1)
if is_cuda:
noise = noise.cuda()
f_sample = GNet(noise)
'''0 for fake'''
labels = labels.fill_(0)
outputs = DNet(f_sample.detach())
f_DNetloss = loss(outputs,labels)
f_DNetloss.backward()
DGz1 = outputs.data.mean()
DNetLoss = r_DNetloss + f_DNetloss
opt_d.step()
'''Train GNet'''
GNet.zero_grad()
'''Fake sample is \real\ to generator net '''
labels = labels.fill_(1)
outputs = DNet(f_sample)
GNetLoss = loss(outputs,labels)
GNetLoss.backward()
DGz2 = outputs.data.mean()
opt_g.step()
if batch % 100 == 0:
print('[%d/%d][%d/%d]'%(epoch,epochs,batch,len(loader)))
print('DNet-Loss:%.4f'%(DNetLoss.data))
print('GNet-Loss:%.4f'%(GNetLoss.data))
print('D(x):%.4f'%(Dx))
print('D[G(z)]:%.4f/%.4f'%(DGz1,DGz2))
print('*'*30)
vutils.save_image(r_sample,
'%s\\real_sample.png'%outpath,
normalize= True)
f_sample = GNet(fix_noise)
vutils.save_image(f_sample,
'%s\\fake_sample_%03d.png'%(outpath,epoch),
normalize= True)
DNetLosses.append(DNetLoss.detach().cpu().numpy())
GNetLosses.append(GNetLoss.detach().cpu().numpy())
plt.plot(range(1,len(DNetLosses)+1),
DNetLosses,'b',
label = 'GNet Loss')
plt.plot(range(1,len(GNetLosses)+1),
GNetLosses,'r',
label = 'DNet Loss')
plt.legend()
|