Improving the quality of generated images with convolutional and Wasserstein GANs
This article implements a deep convolutional GAN (DCGAN) as well as a Wasserstein GAN (WGAN).
Some of the techniques used in this article:
from IPython.display import Image
%matplotlib inline
Transposed convolution
While convolution operations are commonly used to downsample the feature space (for example, by setting the stride to 2, or by adding a pooling layer after the convolutional layer), the transposed convolution operation is typically used to upsample the feature space.
The goal of the transposed convolution is to recover the shape of the input matrix, not the actual matrix values.
In more detail: the original input has size $n\times n$, and a 2-D convolution produces a feature map of size $m\times m$. The transposed convolution then recovers a feature map of size $n\times n$ from the $m\times m$ one. This is illustrated below:
Image(filename='images/17_09.png', width=700)
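As a quick sanity check on this shape arithmetic, here is a small sketch (not part of the original notebook; the helper names are made up) that mirrors the output-size rules tf.keras uses for 'same' and 'valid' padding:

```python
# Output-size arithmetic for (transposed) convolutions; an illustrative
# sketch, assuming stride s, kernel size k, and tf.keras-style
# 'same'/'valid' padding conventions.

def conv_out_size(n, k, s, padding='same'):
    # Regular convolution: downsamples n -> m
    if padding == 'same':
        return -(-n // s)            # ceil(n / s)
    return (n - k) // s + 1          # 'valid'

def conv_transpose_out_size(m, k, s, padding='same'):
    # Transposed convolution: upsamples m -> n; it inverts the shape
    # arithmetic, not the actual values
    if padding == 'same':
        return m * s
    return (m - 1) * s + k           # 'valid'

# A 7x7 feature map upsampled twice with stride 2 recovers 28x28,
# exactly as in the DCGAN generator below:
n = 7
for _ in range(2):
    n = conv_transpose_out_size(n, k=5, s=2, padding='same')
print(n)  # 28
```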
Transposed convolution and deconvolution
Transposed convolution is sometimes also called fractionally strided convolution. A closely related convolution operation in deep learning is called deconvolution.
In fact, deconvolution is defined as the inverse of the regular convolution operation $f$. Given a feature map $\boldsymbol{x}$ and weights $\boldsymbol{w}$, the convolution produces the feature map $\boldsymbol{x}^{\prime}$:
$$f_{w}(\boldsymbol{x})=\boldsymbol{x}^{\prime}\tag{1}$$
The deconvolution operation, written $f^{-1}$, is then defined by:
$$f_{w}^{-1}(f_{w}(\boldsymbol{x}))=\boldsymbol{x}\tag{2}$$
Transposed convolution, in contrast, is only concerned with recovering the dimensionality of the feature space, not the actual values.
Image(filename='images/17_10.png', width=700)
Batch normalization
The main idea is to normalize the layer inputs and prevent their distribution from changing during training, enabling faster and better convergence.
Assume a four-dimensional tensor $\boldsymbol{Z}$ of shape $[m \times h \times w \times c]$, such as the feature maps obtained from a convolution operation. Here $m$ denotes the number of examples in the batch (the batch size), $h\times w$ is the spatial dimension of the feature maps, and $c$ is the number of channels.
BatchNorm can be summarized in the following three steps:
- Compute the mean and variance over the batch:
$$\boldsymbol{\mu}_{B}=\frac{1}{m \times h \times w} \sum_{i, j, k} \boldsymbol{Z}^{[i, j, k, \cdot]}, \quad \boldsymbol{\sigma}_{B}^{2}=\frac{1}{m \times h \times w} \sum_{i, j, k}\left(\boldsymbol{Z}^{[i, j, k, \cdot]}-\boldsymbol{\mu}_{B}\right)^{2}\tag{3}$$
where both $\boldsymbol{\mu}_{B}$ and $\boldsymbol{\sigma}_{B}^{2}$ have size $c$.
- Standardize all examples in the batch:
$$\boldsymbol{Z}_{\mathrm{std}}^{[i]}=\frac{\boldsymbol{Z}^{[i]}-\boldsymbol{\mu}_{B}}{\boldsymbol{\sigma}_{B}+\epsilon}\tag{4}$$
where $\epsilon$ is a small number that ensures numerical stability (i.e., avoids division by zero).
- Scale and shift the normalized inputs using two learnable parameter vectors, $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$:
$$\boldsymbol{A}_{\mathrm{pre}}^{[i]}=\boldsymbol{\gamma} \boldsymbol{Z}_{\mathrm{std}}^{[i]}+\boldsymbol{\beta}\tag{5}$$
This is illustrated below:
Image(filename='images/17_11.png', width=700)
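The three steps of equations 3-5 can be written out directly in NumPy. This is an illustrative sketch with random data and identity scale/shift, not tf.keras's implementation:

```python
import numpy as np

rng = np.random.RandomState(1)
Z = rng.randn(8, 7, 7, 16)                 # [m x h x w x c] feature maps

# Step 1 (eq. 3): per-channel mean and std over batch and spatial dims
mu_B = Z.mean(axis=(0, 1, 2))
sigma_B = Z.std(axis=(0, 1, 2))

# Step 2 (eq. 4): standardize, with a small epsilon for stability
eps = 1e-5
Z_std = (Z - mu_B) / (sigma_B + eps)

# Step 3 (eq. 5): scale and shift with learnable gamma and beta
gamma = np.ones(16)
beta = np.zeros(16)
A_pre = gamma * Z_std + beta

# Each channel now has (approximately) zero mean and unit variance:
print(np.allclose(A_pre.mean(axis=(0, 1, 2)), 0.0, atol=1e-7),
      np.allclose(A_pre.var(axis=(0, 1, 2)), 1.0, atol=1e-3))  # True True
```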
TensorFlow provides tf.keras.layers.BatchNormalization(), which can be used directly when defining a model.
The two parameters $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ above are learned during training; whether the layer normalizes with the current batch statistics (training=True) or with its moving averages (training=False) is controlled via the training argument.
Why does BatchNorm help optimization?
Originally, BatchNorm was developed to reduce so-called internal covariate shift, which refers to the changes in the distribution of a layer's activations that occur during training as the network parameters are updated.
Consider a fixed batch that passes through the network during epoch 1, and record the activations of every layer for this batch.
After iterating over the whole training dataset and updating the model parameters, a second epoch begins in which the previously fixed batch passes through the network again. We then compare the layer activations of the first and second epochs.
Because the network parameters have changed, we observe that the activations have changed as well. This phenomenon is called internal covariate shift, and it was believed to slow down the training of neural networks.
However, later work by S. Santurkar et al. showed that BatchNorm instead makes the surface of the loss function smoother, which gives more robust behavior on non-convex optimization problems.
See the referenced literature for details.
Implementing the generator and discriminator
Image(filename='images/17_12.png', width=700)
The generator takes a vector $\boldsymbol{z}$ as input, uses a fully connected layer to increase its size to 6,272, and reshapes the result into a rank-3 tensor of shape 7x7x128 (spatial dimension 7x7, 128 channels).
A series of transposed convolutions, tf.keras.layers.Conv2DTranspose(), then upsamples the feature maps until the spatial dimension reaches 28x28. The number of channels is halved in each layer, except for the last one,
which uses a single output filter to produce a grayscale image.
Each transposed convolution layer is followed by BatchNorm and a leaky ReLU activation, except the last one, which uses a tanh activation instead.
Image(filename='images/17_13.png', width=700)
The discriminator receives a 28x28x1 image, which passes through four convolutional layers.
It is generally not recommended to use bias units in a layer that is followed by BatchNorm; the bias unit would be redundant, because BatchNorm already has a shift parameter.
This can be set via use_bias=False.
- Setting up Google Colab
import tensorflow as tf
tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
import tensorflow as tf
print(tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))
if tf.config.list_physical_devices('GPU'):
    device_name = tf.test.gpu_device_name()
else:
    device_name = '/device:CPU:0'
print(device_name)
2.1.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
/device:GPU:0
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def make_dcgan_generator(
z_size=20,
output_size=(28, 28, 1),
n_filters=128,
n_blocks=2):
size_factor = 2**n_blocks
hidden_size = (
output_size[0]//size_factor,
output_size[1]//size_factor
)
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(z_size,)),
tf.keras.layers.Dense(
units=n_filters*np.prod(hidden_size),
use_bias=False),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.LeakyReLU(),
tf.keras.layers.Reshape(
(hidden_size[0], hidden_size[1], n_filters)),
tf.keras.layers.Conv2DTranspose(
filters=n_filters, kernel_size=(5, 5), strides=(1, 1),
padding='same', use_bias=False),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.LeakyReLU()
])
nf = n_filters
for i in range(n_blocks):
nf = nf // 2
model.add(
tf.keras.layers.Conv2DTranspose(
filters=nf, kernel_size=(5, 5), strides=(2, 2),
padding='same', use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())
model.add(
tf.keras.layers.Conv2DTranspose(
filters=output_size[2], kernel_size=(5, 5),
strides=(1, 1), padding='same', use_bias=False,
activation='tanh'))
return model
def make_dcgan_discriminator(
input_size=(28, 28, 1),
n_filters=64,
n_blocks=2):
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=input_size),
tf.keras.layers.Conv2D(
filters=n_filters, kernel_size=5,
strides=(1, 1), padding='same'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.LeakyReLU()
])
nf = n_filters
for i in range(n_blocks):
nf = nf*2
model.add(
tf.keras.layers.Conv2D(
filters=nf, kernel_size=(5, 5),
strides=(2, 2), padding='same'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(
filters=1, kernel_size=(7, 7), padding='valid'))
model.add(tf.keras.layers.Reshape((1,)))
return model
gen_model = make_dcgan_generator()
gen_model.summary()
disc_model = make_dcgan_discriminator()
disc_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 6272) 125440
_________________________________________________________________
batch_normalization (BatchNo (None, 6272) 25088
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 6272) 0
_________________________________________________________________
reshape (Reshape) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 7, 7, 128) 409600
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128) 512
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64) 204800
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64) 256
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 32) 51200
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 32) 128
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 28, 28, 32) 0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 1) 800
=================================================================
Total params: 817,824
Trainable params: 804,832
Non-trainable params: 12,992
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 64) 1664
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64) 256
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 28, 28, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 14, 128) 204928
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 128) 512
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 14, 14, 128) 0
_________________________________________________________________
dropout (Dropout) (None, 14, 14, 128) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 7, 7, 256) 819456
_________________________________________________________________
batch_normalization_6 (Batch (None, 7, 7, 256) 1024
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 7, 7, 256) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 7, 7, 256) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 1, 1, 1) 12545
_________________________________________________________________
reshape_1 (Reshape) (None, 1) 0
=================================================================
Total params: 1,040,385
Trainable params: 1,039,489
Non-trainable params: 896
_________________________________________________________________
Different ways of measuring the dissimilarity between two distributions
Let $P(x)$ and $Q(x)$ denote two distributions of a random variable $x$, illustrated below:
Image(filename='images/17_14.png', width=700)
In the TV (total variation) measure, the supremum function denotes the least value that is greater than all elements of a set $S$; in other words, $\sup(S)$ is the least upper bound of $S$.
- TV distance: measures the largest difference between the two distributions over all points;
- EM distance: can be understood as the minimal amount of work needed to transform one distribution into the other; in the EM distance the supremum is replaced by an infimum, taken over all joint distributions whose marginals are the two distributions;
- The KL and JS divergences come from information theory. In contrast to the JS divergence, the KL divergence is not symmetric, i.e., $KL(P \| Q) \neq KL(Q \| P)$.
Some worked examples are shown below:
Image(filename='images/17_15.png', width=800)
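To make these measures concrete, here is a small NumPy sketch (toy discrete distributions chosen for illustration, not the book's examples) computing the TV distance and the KL and JS divergences:

```python
import numpy as np

# Two toy discrete distributions over four outcomes
P = np.array([0.25, 0.25, 0.25, 0.25])
Q = np.array([0.50, 0.25, 0.15, 0.10])

def tv(p, q):
    # Total variation distance; for discrete distributions it equals
    # half the L1 distance between the probability vectors
    return 0.5 * np.abs(p - q).sum()

def kl(p, q):
    # KL divergence (relative entropy), as in equation 7
    return np.sum(p * np.log(p / q))

def js(p, q):
    # JS divergence: symmetrized KL against the mixture M = (P+Q)/2
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

print(tv(P, Q))              # ~0.25
print(kl(P, Q) != kl(Q, P))  # True: KL is not symmetric
print(np.isclose(js(P, Q), js(Q, P)))  # True: JS is symmetric
```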
The relationship between KL divergence and cross-entropy:
KL divergence measures the relative entropy of a distribution P with respect to a reference distribution Q, and is computed as:
$$KL(P \| Q)=-\int P(x) \log (Q(x))\, d x-\left(-\int P(x) \log (P(x))\, d x\right)\tag{6}$$
For discrete distributions, the KL divergence is:
$$KL(P \| Q)=\sum_{i} P\left(x_{i}\right) \log \frac{P\left(x_{i}\right)}{Q\left(x_{i}\right)}\tag{7}$$
which can equivalently be written as:
$$KL(P \| Q)=-\sum_{i} P\left(x_{i}\right) \log \left(Q\left(x_{i}\right)\right)-\left(-\sum_{i} P\left(x_{i}\right) \log \left(P\left(x_{i}\right)\right)\right)\tag{8}$$
The KL divergence can therefore be seen as the cross-entropy between P and Q minus the entropy of P:
$$KL(P \| Q)=H(P, Q)-H(P)\tag{9}$$
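Equation 9 can be verified numerically for two discrete distributions; this is an illustrative sketch (natural logarithms throughout, toy probability vectors):

```python
import numpy as np

P = np.array([0.1, 0.4, 0.5])
Q = np.array([0.8, 0.1, 0.1])

cross_entropy = -np.sum(P * np.log(Q))   # H(P, Q)
entropy = -np.sum(P * np.log(P))         # H(P)
kl_pq = np.sum(P * np.log(P / Q))        # KL(P || Q)

# KL(P||Q) = H(P, Q) - H(P), as in equation 9:
print(np.isclose(kl_pq, cross_entropy - entropy))  # True
```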
Related papers have shown that training GAN models with the JS divergence can be problematic, and recommend using the EM distance instead.
Benefits of using the EM distance:
Following the Wasserstein GAN paper: suppose two distributions P and Q are supported on parallel lines. One of them is the y-axis, the other the line $x=\theta$, with $\theta>0$.
Computing KL, TV, and JS then gives:
$$KL(P \| Q)=+\infty, \quad TV(P, Q)=1, \quad \text{and } JS(P, Q)=\frac{1}{2} \log 2\tag{10}$$
None of these distances depends on $\theta$, whereas the corresponding EM distance is:
$$EM(P, Q)=|\theta|\tag{11}$$
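The parallel-lines example can be made concrete with a small sketch (the helper name is made up). Since the y-coordinates of the two distributions already match, the EM distance reduces to the 1-D Wasserstein distance of the x-coordinates, which for equally many sorted samples is the mean absolute difference of the sorted values:

```python
import numpy as np

def em_distance_1d(samples_p, samples_q):
    # 1-D EM (Wasserstein-1) distance for two equal-size sample sets:
    # mean absolute difference after sorting both
    return np.mean(np.abs(np.sort(samples_p) - np.sort(samples_q)))

theta = 0.7
xs_p = np.zeros(1000)            # x-coordinates under P (the y-axis)
xs_q = np.full(1000, theta)      # x-coordinates under Q (the line x=theta)

# Every point must travel a horizontal distance theta, so the EM
# distance is |theta|, independent of the shape along the line:
print(em_distance_1d(xs_p, xs_q))  # ~0.7
```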
Using the EM distance as a measure in GANs
Let $P_r$ be the distribution of the real examples and $P_g$ the distribution of the generated examples; they correspond to $P$ and $Q$ in the EM distance formula, respectively.
Computing the EM distance is itself an optimization problem and therefore computationally demanding. Fortunately, it can be simplified using the Kantorovich-Rubinstein duality:
$$W\left(P_{r}, P_{g}\right)=\sup _{\|f\|_{L} \leq 1} E_{u \in P_{r}}[f(u)]-E_{v \in P_{g}}[f(v)]\tag{12}$$
Here the supremum is taken over all 1-Lipschitz continuous functions, denoted $\|f\|_{L} \leq 1$.
Lipschitz continuity:
For 1-Lipschitz continuity, the function $f$ must satisfy:
$$\left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq\left|x_{1}-x_{2}\right|\tag{13}$$
More generally, a real-valued function $f: \mathbb{R}\rightarrow \mathbb{R}$ is K-Lipschitz continuous if it satisfies:
$$\left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq K\left|x_{1}-x_{2}\right|\tag{14}$$
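A Lipschitz constant $K$ as in equation 14 can be estimated numerically by taking the largest slope between sample points; this is an illustrative sketch (the helper name is made up), using tanh, whose derivative never exceeds 1, and a simple 3-Lipschitz linear function:

```python
import numpy as np

def lipschitz_estimate(f, lo=-5.0, hi=5.0, n=10001):
    # Estimate K by the largest secant slope over a fine grid
    x = np.linspace(lo, hi, n)
    y = f(x)
    return np.max(np.abs(np.diff(y) / np.diff(x)))

# tanh is 1-Lipschitz (its derivative is at most 1):
print(lipschitz_estimate(np.tanh) <= 1.0 + 1e-9)   # True

# f(x) = 3x is 3-Lipschitz:
print(lipschitz_estimate(lambda x: 3 * x))         # ~3.0
```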
When training a GAN model with the Wasserstein distance, the losses for the generator and the discriminator are defined as follows.
For the discriminator on real examples:
$$L_{\text {real }}^{D}=-\frac{1}{N} \sum_{i} D\left(\boldsymbol{x}_{i}\right)\tag{15}$$
For the discriminator on generated examples:
$$L_{\text {fake }}^{D}=\frac{1}{N} \sum_{i} D\left(G\left(\boldsymbol{z}_{i}\right)\right)\tag{16}$$
For the generator:
$$L^{G}=-\frac{1}{N} \sum_{i} D\left(G\left(\boldsymbol{z}_{i}\right)\right)\tag{17}$$
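Equations 15-17 operate directly on the critic's raw scores; a sketch with made-up score vectors (not model output) shows how they relate:

```python
import numpy as np

d_real = np.array([0.9, 1.2, 0.7])    # D(x_i) on real examples
d_fake = np.array([-0.5, 0.1, -0.2])  # D(G(z_i)) on generated examples

d_loss_real = -np.mean(d_real)        # eq. 15
d_loss_fake = np.mean(d_fake)         # eq. 16
g_loss = -np.mean(d_fake)             # eq. 17

# The critic pushes real scores up and fake scores down; the generator
# pushes fake scores up, so its loss is exactly the negative of the
# critic's fake-example loss:
print(g_loss == -d_loss_fake)   # True
```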
Gradient penalty (GP)
GAN models with gradient penalty (WGAN-GP):
The literature points out that using weight clipping in a WGAN can lead to vanishing or exploding gradients.
Instead of weight clipping, this article therefore uses a gradient penalty (GP); the corresponding network is the WGAN with gradient penalty (WGAN-GP).
The procedure for adding the GP in each iteration can be summarized as follows:
- For each pair of real and fake examples in a given batch, draw a random number $\alpha^{[i]}$ sampled from a uniform distribution, $\alpha^{[i]} \in U(0,1)$.
- Compute an interpolation between the real and generated (fake) examples, $\breve{\boldsymbol{x}}^{[i]}=\alpha \boldsymbol{x}^{[i]}+(1-\alpha) \widetilde{\boldsymbol{x}}^{[i]}$, which yields a batch of interpolated examples.
- Compute the output of the discriminator network for all interpolated examples, $D\left(\breve{\boldsymbol{x}}^{[i]}\right)$.
- Compute the gradient of the discriminator's output with respect to each interpolated example, $\nabla_{\breve{\boldsymbol{x}}^{[i]}} D\left(\breve{\boldsymbol{x}}^{[i]}\right)$.
- Compute the gradient penalty:
$$L_{g p}^{D}=\frac{1}{N} \sum_{i}\left(\left\|\nabla_{\breve{\boldsymbol{x}}^{[i]}} D\left(\breve{\boldsymbol{x}}^{[i]}\right)\right\|_{2}-1\right)^{2}\tag{18}$$
The total loss for the discriminator is therefore:
$$L_{\text {total }}^{D}=L_{\text {real }}^{D}+L_{\text {fake }}^{D}+\lambda L_{g p}^{D}\tag{19}$$
where $\lambda$ is a hyperparameter.
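The gradient-penalty steps can be sketched in NumPy with a simple stand-in critic whose gradient is known in closed form (a toy illustration, not the TensorFlow training code; the critic $D(x) = 2\sum_j x_j$ is made up so that $\nabla D$ is the constant vector $[2, 2, 2]$):

```python
import numpy as np

rng = np.random.RandomState(1)
N = 4
x_real = rng.randn(N, 3)              # "real" examples
x_fake = rng.randn(N, 3)              # "generated" examples

# Draw alpha ~ U(0, 1) per example and interpolate
alpha = rng.uniform(0.0, 1.0, size=(N, 1))
x_interp = alpha * x_real + (1 - alpha) * x_fake

# For D(x) = 2*sum(x), the gradient w.r.t. each interpolated example
# is [2, 2, 2], so its L2 norm is 2*sqrt(3) everywhere:
grads = np.full_like(x_interp, 2.0)
grad_l2 = np.sqrt(np.sum(grads**2, axis=1))
gp = np.mean((grad_l2 - 1.0)**2)      # equation 18

print(np.isclose(gp, (2 * np.sqrt(3) - 1.0)**2))  # True
```

Since the norm is constant here, the penalty is simply $(2\sqrt{3}-1)^2$ for every example; in the real training loop the norm varies per example, which is why the mean in equation 18 matters.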
Training the WGAN-GP model
mnist_bldr = tfds.builder('mnist')
mnist_bldr.download_and_prepare()
mnist = mnist_bldr.as_dataset(shuffle_files=False)
def preprocess(ex, mode='uniform'):
image = ex['image']
image = tf.image.convert_image_dtype(image, tf.float32)
image = image*2 - 1.0
if mode == 'uniform':
input_z = tf.random.uniform(
shape=(z_size,), minval=-1.0, maxval=1.0)
elif mode == 'normal':
input_z = tf.random.normal(shape=(z_size,))
return input_z, image
num_epochs = 100
batch_size = 128
image_size = (28, 28)
z_size = 20
mode_z = 'uniform'
lambda_gp = 10.0
tf.random.set_seed(1)
np.random.seed(1)
mnist_trainset = mnist['train']
mnist_trainset = mnist_trainset.map(preprocess)
mnist_trainset = mnist_trainset.shuffle(10000)
mnist_trainset = mnist_trainset.batch(
batch_size, drop_remainder=True)
with tf.device(device_name):
gen_model = make_dcgan_generator()
gen_model.build(input_shape=(None, z_size))
gen_model.summary()
disc_model = make_dcgan_discriminator()
disc_model.build(input_shape=(None, np.prod(image_size)))
disc_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 6272) 125440
_________________________________________________________________
batch_normalization_7 (Batch (None, 6272) 25088
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 6272) 0
_________________________________________________________________
reshape_2 (Reshape) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 7, 7, 128) 409600
_________________________________________________________________
batch_normalization_8 (Batch (None, 7, 7, 128) 512
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 14, 14, 64) 204800
_________________________________________________________________
batch_normalization_9 (Batch (None, 14, 14, 64) 256
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 28, 28, 32) 51200
_________________________________________________________________
batch_normalization_10 (Batc (None, 28, 28, 32) 128
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 28, 28, 32) 0
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 28, 28, 1) 800
=================================================================
Total params: 817,824
Trainable params: 804,832
Non-trainable params: 12,992
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_4 (Conv2D) (None, 28, 28, 64) 1664
_________________________________________________________________
batch_normalization_11 (Batc (None, 28, 28, 64) 256
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 28, 28, 64) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 14, 14, 128) 204928
_________________________________________________________________
batch_normalization_12 (Batc (None, 14, 14, 128) 512
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 14, 14, 128) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 14, 14, 128) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 7, 7, 256) 819456
_________________________________________________________________
batch_normalization_13 (Batc (None, 7, 7, 256) 1024
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 7, 7, 256) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 7, 7, 256) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 1, 1, 1) 12545
_________________________________________________________________
reshape_3 (Reshape) (None, 1) 0
=================================================================
Total params: 1,040,385
Trainable params: 1,039,489
Non-trainable params: 896
_________________________________________________________________
import time
g_optimizer = tf.keras.optimizers.Adam(0.0002)
d_optimizer = tf.keras.optimizers.Adam(0.0002)
if mode_z == 'uniform':
fixed_z = tf.random.uniform(
shape=(batch_size, z_size),
minval=-1, maxval=1)
elif mode_z == 'normal':
fixed_z = tf.random.normal(
shape=(batch_size, z_size))
def create_samples(g_model, input_z):
g_output = g_model(input_z, training=False)
images = tf.reshape(g_output, (batch_size, *image_size))
return (images+1)/2.0
all_losses = []
epoch_samples = []
start_time = time.time()
for epoch in range(1, num_epochs+1):
epoch_losses = []
for i,(input_z,input_real) in enumerate(mnist_trainset):
with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
g_output = gen_model(input_z, training=True)
d_critics_real = disc_model(input_real, training=True)
d_critics_fake = disc_model(g_output, training=True)
g_loss = -tf.math.reduce_mean(d_critics_fake)
d_loss_real = -tf.math.reduce_mean(d_critics_real)
d_loss_fake = tf.math.reduce_mean(d_critics_fake)
d_loss = d_loss_real + d_loss_fake
with tf.GradientTape() as gp_tape:
alpha = tf.random.uniform(
shape=[d_critics_real.shape[0], 1, 1, 1],
minval=0.0, maxval=1.0)
interpolated = (
alpha*input_real + (1-alpha)*g_output)
gp_tape.watch(interpolated)
d_critics_intp = disc_model(interpolated)
grads_intp = gp_tape.gradient(
d_critics_intp, [interpolated,])[0]
grads_intp_l2 = tf.sqrt(
tf.reduce_sum(tf.square(grads_intp), axis=[1, 2, 3]))
grad_penalty = tf.reduce_mean(tf.square(grads_intp_l2 - 1.0))
d_loss = d_loss + lambda_gp*grad_penalty
d_grads = d_tape.gradient(d_loss, disc_model.trainable_variables)
d_optimizer.apply_gradients(
grads_and_vars=zip(d_grads, disc_model.trainable_variables))
g_grads = g_tape.gradient(g_loss, gen_model.trainable_variables)
g_optimizer.apply_gradients(
grads_and_vars=zip(g_grads, gen_model.trainable_variables))
epoch_losses.append(
(g_loss.numpy(), d_loss.numpy(),
d_loss_real.numpy(), d_loss_fake.numpy()))
all_losses.append(epoch_losses)
print('Epoch {:-3d} | ET {:.2f} min | Avg Losses >>'
' G/D {:6.2f}/{:6.2f} [D-Real: {:6.2f} D-Fake: {:6.2f}]'
.format(epoch, (time.time() - start_time)/60,
*list(np.mean(all_losses[-1], axis=0)))
)
epoch_samples.append(
create_samples(gen_model, fixed_z).numpy()
)
Epoch 1 | ET 1.58 min | Avg Losses >> G/D 186.47/-305.89 [D-Real: -204.71 D-Fake: -186.47]
Epoch 2 | ET 3.17 min | Avg Losses >> G/D 125.83/-23.81 [D-Real: -88.27 D-Fake: -125.83]
Epoch 3 | ET 4.77 min | Avg Losses >> G/D 81.43/ -2.47 [D-Real: 29.42 D-Fake: -81.43]
Epoch 4 | ET 6.37 min | Avg Losses >> G/D 43.89/ -6.47 [D-Real: -14.41 D-Fake: -43.89]
Epoch 5 | ET 7.97 min | Avg Losses >> G/D 37.12/ -5.75 [D-Real: 4.77 D-Fake: -37.12]
Epoch 6 | ET 9.56 min | Avg Losses >> G/D 37.59/-10.89 [D-Real: 11.12 D-Fake: -37.59]
Epoch 7 | ET 11.16 min | Avg Losses >> G/D 48.30/-13.90 [D-Real: 32.22 D-Fake: -48.30]
Epoch 8 | ET 12.76 min | Avg Losses >> G/D 43.40/-20.73 [D-Real: 19.27 D-Fake: -43.40]
Epoch 9 | ET 14.36 min | Avg Losses >> G/D 51.94/-32.62 [D-Real: 10.19 D-Fake: -51.94]
Epoch 10 | ET 15.97 min | Avg Losses >> G/D 54.99/-34.56 [D-Real: 10.05 D-Fake: -54.99]
Epoch 11 | ET 17.57 min | Avg Losses >> G/D 78.72/-43.58 [D-Real: 21.09 D-Fake: -78.72]
Epoch 12 | ET 19.16 min | Avg Losses >> G/D 83.79/-43.11 [D-Real: 26.02 D-Fake: -83.79]
Epoch 13 | ET 20.76 min | Avg Losses >> G/D 84.35/-33.70 [D-Real: 29.12 D-Fake: -84.35]
Epoch 14 | ET 22.36 min | Avg Losses >> G/D 78.19/-30.97 [D-Real: 34.40 D-Fake: -78.19]
Epoch 15 | ET 23.96 min | Avg Losses >> G/D 75.11/-44.77 [D-Real: 18.49 D-Fake: -75.11]
Epoch 16 | ET 25.56 min | Avg Losses >> G/D 83.66/-44.93 [D-Real: 25.06 D-Fake: -83.66]
Epoch 17 | ET 27.16 min | Avg Losses >> G/D 97.82/-41.78 [D-Real: 42.69 D-Fake: -97.82]
Epoch 18 | ET 28.76 min | Avg Losses >> G/D 102.77/-41.13 [D-Real: 44.20 D-Fake: -102.77]
Epoch 19 | ET 30.36 min | Avg Losses >> G/D 98.06/-45.39 [D-Real: 36.53 D-Fake: -98.06]
Epoch 20 | ET 31.96 min | Avg Losses >> G/D 92.78/-43.22 [D-Real: 39.53 D-Fake: -92.78]
Epoch 21 | ET 33.56 min | Avg Losses >> G/D 119.17/-45.35 [D-Real: 63.00 D-Fake: -119.17]
Epoch 22 | ET 35.15 min | Avg Losses >> G/D 167.63/-45.22 [D-Real: 109.72 D-Fake: -167.63]
Epoch 23 | ET 36.75 min | Avg Losses >> G/D 143.33/-43.10 [D-Real: 78.66 D-Fake: -143.33]
Epoch 24 | ET 38.35 min | Avg Losses >> G/D 165.30/-45.59 [D-Real: 108.96 D-Fake: -165.30]
Epoch 25 | ET 39.95 min | Avg Losses >> G/D 185.48/-49.65 [D-Real: 120.26 D-Fake: -185.48]
Epoch 26 | ET 41.54 min | Avg Losses >> G/D 177.84/-46.74 [D-Real: 111.94 D-Fake: -177.84]
Epoch 27 | ET 43.14 min | Avg Losses >> G/D 235.96/-57.86 [D-Real: 165.91 D-Fake: -235.96]
Epoch 28 | ET 44.74 min | Avg Losses >> G/D 246.81/-30.65 [D-Real: 198.56 D-Fake: -246.81]
Epoch 29 | ET 46.34 min | Avg Losses >> G/D 296.19/-36.56 [D-Real: 256.34 D-Fake: -296.19]
Epoch 30 | ET 47.94 min | Avg Losses >> G/D 317.43/-51.34 [D-Real: 260.19 D-Fake: -317.43]
Epoch 31 | ET 49.54 min | Avg Losses >> G/D 343.99/-47.43 [D-Real: 285.17 D-Fake: -343.99]
Epoch 32 | ET 51.14 min | Avg Losses >> G/D 367.00/-46.06 [D-Real: 303.26 D-Fake: -367.00]
Epoch 33 | ET 52.74 min | Avg Losses >> G/D 359.17/ 15.57 [D-Real: 311.80 D-Fake: -359.17]
Epoch 34 | ET 54.35 min | Avg Losses >> G/D 306.75/-28.72 [D-Real: 252.32 D-Fake: -306.75]
Epoch 35 | ET 55.94 min | Avg Losses >> G/D 338.60/-47.47 [D-Real: 282.77 D-Fake: -338.60]
Epoch 36 | ET 57.54 min | Avg Losses >> G/D 348.70/-51.16 [D-Real: 285.71 D-Fake: -348.70]
Epoch 37 | ET 59.14 min | Avg Losses >> G/D 339.03/-42.02 [D-Real: 291.79 D-Fake: -339.03]
Epoch 38 | ET 60.73 min | Avg Losses >> G/D 384.91/-48.19 [D-Real: 321.08 D-Fake: -384.91]
Epoch 39 | ET 62.34 min | Avg Losses >> G/D 377.89/-46.81 [D-Real: 322.20 D-Fake: -377.89]
Epoch 40 | ET 63.94 min | Avg Losses >> G/D 356.70/-41.24 [D-Real: 307.53 D-Fake: -356.70]
Epoch 41 | ET 65.54 min | Avg Losses >> G/D 352.23/-36.78 [D-Real: 312.28 D-Fake: -352.23]
Epoch 42 | ET 67.14 min | Avg Losses >> G/D 512.82/-54.96 [D-Real: 452.57 D-Fake: -512.82]
Epoch 43 | ET 68.74 min | Avg Losses >> G/D 525.44/-58.37 [D-Real: 451.38 D-Fake: -525.44]
Epoch 44 | ET 70.34 min | Avg Losses >> G/D 537.93/-46.99 [D-Real: 471.04 D-Fake: -537.93]
Epoch 45 | ET 71.94 min | Avg Losses >> G/D 586.10/-53.15 [D-Real: 523.73 D-Fake: -586.10]
Epoch 46 | ET 73.53 min | Avg Losses >> G/D 631.31/-61.20 [D-Real: 556.68 D-Fake: -631.31]
Epoch 47 | ET 75.14 min | Avg Losses >> G/D 644.35/-66.45 [D-Real: 571.78 D-Fake: -644.35]
Epoch 48 | ET 76.74 min | Avg Losses >> G/D 766.90/-68.34 [D-Real: 683.32 D-Fake: -766.90]
Epoch 49 | ET 78.33 min | Avg Losses >> G/D 715.37/-11.67 [D-Real: 654.83 D-Fake: -715.37]
Epoch 50 | ET 79.93 min | Avg Losses >> G/D 707.44/-65.47 [D-Real: 631.52 D-Fake: -707.44]
Epoch 51 | ET 81.53 min | Avg Losses >> G/D 682.43/-77.68 [D-Real: 600.63 D-Fake: -682.43]
Epoch 52 | ET 83.12 min | Avg Losses >> G/D 883.31/-74.31 [D-Real: 797.64 D-Fake: -883.31]
Epoch 53 | ET 84.72 min | Avg Losses >> G/D 828.10/ -4.36 [D-Real: 771.68 D-Fake: -828.10]
Epoch 54 | ET 86.32 min | Avg Losses >> G/D 916.28/-18.10 [D-Real: 871.71 D-Fake: -916.28]
Epoch 55 | ET 87.92 min | Avg Losses >> G/D 890.63/-52.51 [D-Real: 832.85 D-Fake: -890.63]
Epoch 56 | ET 89.51 min | Avg Losses >> G/D 706.38/-60.17 [D-Real: 642.28 D-Fake: -706.38]
Epoch 57 | ET 91.11 min | Avg Losses >> G/D 841.65/-93.87 [D-Real: 740.63 D-Fake: -841.65]
Epoch 58 | ET 92.71 min | Avg Losses >> G/D 1030.48/-117.84 [D-Real: 902.95 D-Fake: -1030.48]
Epoch 59 | ET 94.31 min | Avg Losses >> G/D 965.95/-89.45 [D-Real: 867.79 D-Fake: -965.95]
Epoch 60 | ET 95.91 min | Avg Losses >> G/D 1197.10/-105.03 [D-Real: 1055.86 D-Fake: -1197.10]
Epoch 61 | ET 97.51 min | Avg Losses >> G/D 1094.58/-125.98 [D-Real: 946.71 D-Fake: -1094.58]
Epoch 62 | ET 99.11 min | Avg Losses >> G/D 1158.20/-130.87 [D-Real: 1009.98 D-Fake: -1158.20]
Epoch 63 | ET 100.70 min | Avg Losses >> G/D 1015.38/-102.59 [D-Real: 907.90 D-Fake: -1015.38]
Epoch 64 | ET 102.30 min | Avg Losses >> G/D 1452.54/-183.59 [D-Real: 1250.75 D-Fake: -1452.54]
Epoch 65 | ET 103.90 min | Avg Losses >> G/D 1532.75/-176.89 [D-Real: 1327.26 D-Fake: -1532.75]
Epoch 66 | ET 105.50 min | Avg Losses >> G/D 1372.27/-184.02 [D-Real: 1173.63 D-Fake: -1372.27]
Epoch 67 | ET 107.10 min | Avg Losses >> G/D 1476.47/-151.54 [D-Real: 1286.28 D-Fake: -1476.47]
Epoch 68 | ET 108.70 min | Avg Losses >> G/D 1337.81/-165.44 [D-Real: 1155.44 D-Fake: -1337.81]
Epoch 69 | ET 110.30 min | Avg Losses >> G/D 1868.09/-265.95 [D-Real: 1564.99 D-Fake: -1868.09]
Epoch 70 | ET 111.90 min | Avg Losses >> G/D 1815.22/155.57 [D-Real: 1589.90 D-Fake: -1815.22]
Epoch 71 | ET 113.50 min | Avg Losses >> G/D 2016.40/-162.41 [D-Real: 1828.58 D-Fake: -2016.40]
Epoch 72 | ET 115.10 min | Avg Losses >> G/D 2118.12/-280.26 [D-Real: 1827.30 D-Fake: -2118.12]
Epoch 73 | ET 116.69 min | Avg Losses >> G/D 2143.48/-318.27 [D-Real: 1818.62 D-Fake: -2143.48]
Epoch 74 | ET 118.29 min | Avg Losses >> G/D 2188.90/-349.97 [D-Real: 1830.72 D-Fake: -2188.90]
Epoch 75 | ET 119.90 min | Avg Losses >> G/D 2362.95/-404.66 [D-Real: 1941.11 D-Fake: -2362.95]
Epoch 76 | ET 121.50 min | Avg Losses >> G/D 2443.07/-427.19 [D-Real: 2004.03 D-Fake: -2443.07]
Epoch 77 | ET 123.10 min | Avg Losses >> G/D 2404.13/-391.35 [D-Real: 1998.95 D-Fake: -2404.13]
Epoch 78 | ET 124.70 min | Avg Losses >> G/D 2562.10/-411.77 [D-Real: 2129.07 D-Fake: -2562.10]
Epoch 79 | ET 126.30 min | Avg Losses >> G/D 2826.90/-540.18 [D-Real: 2257.65 D-Fake: -2826.90]
Epoch 80 | ET 127.90 min | Avg Losses >> G/D 2765.97/-489.78 [D-Real: 2263.39 D-Fake: -2765.97]
Epoch 81 | ET 129.50 min | Avg Losses >> G/D 2775.33/-561.67 [D-Real: 2196.64 D-Fake: -2775.33]
Epoch 82 | ET 131.10 min | Avg Losses >> G/D 3329.70/-599.79 [D-Real: 2616.69 D-Fake: -3329.70]
Epoch 83 | ET 132.70 min | Avg Losses >> G/D 3475.59/-675.97 [D-Real: 2752.62 D-Fake: -3475.59]
Epoch 84 | ET 134.29 min | Avg Losses >> G/D 3634.00/-782.03 [D-Real: 2835.74 D-Fake: -3634.00]
Epoch 85 | ET 135.90 min | Avg Losses >> G/D 3804.03/-801.69 [D-Real: 2961.38 D-Fake: -3804.03]
Epoch 86 | ET 137.50 min | Avg Losses >> G/D 3938.59/-868.39 [D-Real: 3045.14 D-Fake: -3938.59]
Epoch 87 | ET 139.10 min | Avg Losses >> G/D 3951.56/-853.02 [D-Real: 3086.04 D-Fake: -3951.56]
Epoch 88 | ET 140.70 min | Avg Losses >> G/D 3248.74/-723.00 [D-Real: 2495.87 D-Fake: -3248.74]
Epoch 89 | ET 142.29 min | Avg Losses >> G/D 4644.86/-1046.76 [D-Real: 3557.25 D-Fake: -4644.86]
Epoch 90 | ET 143.89 min | Avg Losses >> G/D 4885.19/-1035.63 [D-Real: 3767.10 D-Fake: -4885.19]
Epoch 91 | ET 145.49 min | Avg Losses >> G/D 3802.31/-880.71 [D-Real: 2886.02 D-Fake: -3802.31]
Epoch 92 | ET 147.09 min | Avg Losses >> G/D 4992.27/-1107.25 [D-Real: 3849.27 D-Fake: -4992.27]
Epoch 93 | ET 148.70 min | Avg Losses >> G/D 5438.34/-1285.65 [D-Real: 4107.10 D-Fake: -5438.34]
Epoch 94 | ET 150.30 min | Avg Losses >> G/D 5705.12/-1254.32 [D-Real: 4341.58 D-Fake: -5705.12]
Epoch 95 | ET 151.90 min | Avg Losses >> G/D 6116.86/-1472.97 [D-Real: 4579.57 D-Fake: -6116.86]
Epoch 96 | ET 153.50 min | Avg Losses >> G/D 6402.22/-1595.38 [D-Real: 4765.93 D-Fake: -6402.22]
Epoch 97 | ET 155.10 min | Avg Losses >> G/D 6487.41/-1598.25 [D-Real: 4807.29 D-Fake: -6487.41]
Epoch 98 | ET 156.70 min | Avg Losses >> G/D 5330.31/-1202.23 [D-Real: 4032.31 D-Fake: -5330.31]
Epoch 99 | ET 158.30 min | Avg Losses >> G/D 6946.26/-1631.31 [D-Real: 5270.38 D-Fake: -6946.26]
Epoch 100 | ET 159.90 min | Avg Losses >> G/D 7301.76/-1722.93 [D-Real: 5541.48 D-Fake: -7301.76]
import itertools
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(1, 1, 1)
g_losses = [item[0] for item in itertools.chain(*all_losses)]
d_losses = [item[1] for item in itertools.chain(*all_losses)]
plt.plot(g_losses, label='Generator loss', alpha=0.95)
plt.plot(d_losses, label='Discriminator loss', alpha=0.95)
plt.legend(fontsize=20)
ax.set_xlabel('Iteration', size=15)
ax.set_ylabel('Loss', size=15)
epochs = np.arange(1, 101)
epoch2iter = lambda e: e*len(all_losses[-1])
epoch_ticks = [1, 20, 40, 60, 80, 100]
newpos = [epoch2iter(e) for e in epoch_ticks]
ax2 = ax.twiny()
ax2.set_xticks(newpos)
ax2.set_xticklabels(epoch_ticks)
ax2.xaxis.set_ticks_position('bottom')
ax2.xaxis.set_label_position('bottom')
ax2.spines['bottom'].set_position(('outward', 60))
ax2.set_xlabel('Epoch', size=15)
ax2.set_xlim(ax.get_xlim())
ax.tick_params(axis='both', which='major', labelsize=15)
ax2.tick_params(axis='both', which='major', labelsize=15)
plt.show()
selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i,e in enumerate(selected_epochs):
for j in range(5):
ax = fig.add_subplot(6, 5, i*5+j+1)
ax.set_xticks([])
ax.set_yticks([])
if j == 0:
ax.text(
-0.06, 0.5, 'Epoch {}'.format(e),
rotation=90, size=18, color='red',
horizontalalignment='right',
verticalalignment='center',
transform=ax.transAxes)
image = epoch_samples[e-1][j]
ax.imshow(image, cmap='gray_r')
plt.show()
Mode collapse
Because of the adversarial nature of GAN models, they are notoriously hard to train. A common cause of failure is that the generator gets stuck in a small subspace and learns to generate only similar samples.
This is illustrated below:
Image(filename='images/17_16.png', width=600)
Besides the vanishing and exploding gradient problems seen earlier, several other factors can also make training GAN models difficult.
Here are some training tricks mentioned in the literature:
Mini-batch discrimination:
Batches composed entirely of real examples or entirely of generated examples are fed to the discriminator separately;
Feature matching:
The generator's objective function is slightly modified by adding an extra term that minimizes the difference between the original and generated images based on the discriminator's intermediate representations (feature maps);