
[Artificial Intelligence] Improving the Quality of Generated Image Data with Transposed Convolutions and WGAN

Improving the quality of synthesized images using convolutions and Wasserstein GAN

In this article, we implement a deep convolutional GAN (DCGAN) as well as a Wasserstein GAN (WGAN).

Techniques used in this article:

  • Transposed convolution

  • Batch normalization (BatchNorm)

  • WGAN

  • Gradient penalty

from IPython.display import Image
%matplotlib inline

Transposed convolution

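While convolution operations are commonly used to downsample the feature space (for example, by setting the stride to 2 or by adding a pooling layer after a convolutional layer), a transposed convolution is typically used to upsample the feature space.

The goal of a transposed convolution is to recover the shape of the input matrix, not its actual values.

More specifically: suppose the original input has size $n \times n$ and a 2D convolution produces a feature map of size $m \times m$. A transposed convolution takes the $m \times m$ feature map and produces a feature map of size $n \times n$ again. This is illustrated below: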

Image(filename='images/17_09.png', width=700)

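As a quick check of the shape arithmetic, the following sketch computes the output length along one spatial axis for a strided convolution and for a transposed convolution. It assumes the 'same'/'valid' padding conventions of Keras Conv2D and Conv2DTranspose without output padding; the helper names are hypothetical.

```python
import math

def conv_output_length(n, kernel_size, stride, padding="same"):
    """Output length of a strided convolution along one spatial axis."""
    if padding == "same":
        return math.ceil(n / stride)
    if padding == "valid":
        return (n - kernel_size) // stride + 1
    raise ValueError(f"unsupported padding: {padding}")

def conv_transpose_output_length(n, kernel_size, stride, padding="same"):
    """Output length of a transposed convolution (no output padding)."""
    if padding == "same":
        return n * stride
    if padding == "valid":
        return n * stride + max(kernel_size - stride, 0)
    raise ValueError(f"unsupported padding: {padding}")

# Downsample 28 -> 14 -> 7 with two stride-2 convolutions ...
size, down = 28, [28]
for _ in range(2):
    size = conv_output_length(size, kernel_size=5, stride=2)
    down.append(size)

# ... then upsample 7 -> 14 -> 28 with two stride-2 transposed convolutions.
up = [size]
for _ in range(2):
    size = conv_transpose_output_length(size, kernel_size=5, stride=2)
    up.append(size)

print(down, up)  # [28, 14, 7] [7, 14, 28]
```

The 7 → 14 → 28 upsampling path is the same one used by the generator defined later in this article.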

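Transposed convolution and deconvolution

Transposed convolution is sometimes also called fractionally strided convolution. A closely related operation in deep learning is deconvolution.

Strictly speaking, deconvolution is defined as the inverse of a regular convolution operation $f$. Given a feature map $\boldsymbol{x}$ and weights $\boldsymbol{w}$, the convolution produces a new feature map $\boldsymbol{x}'$:

$$f_{\boldsymbol{w}}(\boldsymbol{x})=\boldsymbol{x}' \tag{1}$$

The deconvolution operation, denoted $f^{-1}$, satisfies:

$$f_{\boldsymbol{w}}^{-1}\left(f_{\boldsymbol{w}}(\boldsymbol{x})\right)=\boldsymbol{x} \tag{2}$$

A transposed convolution, in contrast, only aims to recover the dimensionality of the feature space, not the actual values.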

Image(filename='images/17_10.png', width=700)

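One way to see where the name comes from: a convolution can be written as multiplication by a sparse matrix $C$, and the corresponding transposed convolution multiplies by $C^\top$. The minimal 1D sketch below (no padding, stride 2, hypothetical helper conv1d_matrix) shows that $C^\top$ maps the output back to the input's shape but not to its values, which is the difference from a true deconvolution $f^{-1}$.

```python
import numpy as np

def conv1d_matrix(n, kernel, stride):
    """Matrix C such that C @ x is a 1D 'valid' convolution of x with kernel
    (cross-correlation, as in deep learning libraries)."""
    k = len(kernel)
    m = (n - k) // stride + 1          # output length of the convolution
    C = np.zeros((m, n))
    for row in range(m):
        start = row * stride
        C[row, start:start + k] = kernel
    return C

rng = np.random.default_rng(0)
x = rng.normal(size=8)                 # input of length n = 8
kernel = np.array([1.0, 2.0, 1.0])
C = conv1d_matrix(len(x), kernel, stride=2)

y = C @ x                              # convolution: length 8 -> 3
x_up = C.T @ y                         # transposed convolution: length 3 -> 8

print(C.shape, y.shape, x_up.shape)    # (3, 8) (3,) (8,)
print(np.allclose(x_up, x))            # False: shape recovered, values not
```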

Batch normalization (BatchNorm)

The main idea is to normalize the layer inputs and prevent changes in their distribution during training, which enables faster and better convergence.

Consider a four-dimensional tensor $\boldsymbol{Z}$ of shape $[m \times h \times w \times c]$, the feature map obtained after a convolution.

Here $m$ is the number of examples in the batch (the batch size), $h \times w$ are the spatial dimensions of the feature map, and $c$ is the number of channels.

BatchNorm can be summarized in three steps:

  • Compute the mean and standard deviation of the net inputs for each mini-batch:

$$\boldsymbol{\mu}_{B}=\frac{1}{m \times h \times w} \sum_{i, j, k} \boldsymbol{Z}^{[i, j, k, :]}, \quad \boldsymbol{\sigma}_{B}^{2}=\frac{1}{m \times h \times w} \sum_{i, j, k}\left(\boldsymbol{Z}^{[i, j, k, :]}-\boldsymbol{\mu}_{B}\right)^{2} \tag{3}$$

where $\boldsymbol{\mu}_{B}$ and $\boldsymbol{\sigma}_{B}^{2}$ both have size $c$.

  • Standardize the net inputs of all examples in the batch:

$$\boldsymbol{Z}_{\mathrm{std}}^{[i]}=\frac{\boldsymbol{Z}^{[i]}-\boldsymbol{\mu}_{B}}{\boldsymbol{\sigma}_{B}+\epsilon} \tag{4}$$

where $\epsilon$ is a small number that ensures numerical stability (it prevents division by zero).

  • Scale and shift the standardized net inputs using two learnable parameter vectors, $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$:

$$\boldsymbol{A}_{\mathrm{pre}}^{[i]}=\boldsymbol{\gamma} \boldsymbol{Z}_{\mathrm{std}}^{[i]}+\boldsymbol{\beta} \tag{5}$$

This is illustrated geometrically below:

Image(filename='images/17_11.png', width=700)

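The three steps (Eqs. 3-5) can be sketched in NumPy for a tensor of shape [m, h, w, c]. This is an illustrative sketch, not TensorFlow's implementation: it places $\epsilon$ as in Eq. 4, whereas tf.keras.layers.BatchNormalization divides by $\sqrt{\sigma_B^2+\epsilon}$.

```python
import numpy as np

def batchnorm_forward(Z, gamma, beta, eps=1e-3):
    """BatchNorm for Z of shape [m, h, w, c] following Eqs. 3-5:
    statistics are computed per channel over the batch and spatial axes."""
    mu = Z.mean(axis=(0, 1, 2))                   # Eq. 3, shape (c,)
    var = ((Z - mu) ** 2).mean(axis=(0, 1, 2))    # Eq. 3, shape (c,)
    Z_std = (Z - mu) / (np.sqrt(var) + eps)       # Eq. 4 (sigma_B + eps)
    return gamma * Z_std + beta                   # Eq. 5

rng = np.random.default_rng(1)
Z = rng.normal(loc=5.0, scale=3.0, size=(16, 7, 7, 4))  # m=16, h=w=7, c=4
gamma = np.full(4, 2.0)
beta = np.array([0.0, 1.0, -1.0, 0.5])
A = batchnorm_forward(Z, gamma, beta)

print(A.mean(axis=(0, 1, 2)).round(3))   # approximately beta
print(A.std(axis=(0, 1, 2)).round(3))    # approximately gamma (slightly below, due to eps)
```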

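TensorFlow provides tf.keras.layers.BatchNormalization(), which can be used directly as a layer when defining models and performs all of the steps above.

Note that the layer's behavior depends on whether it is called with training=True or training=False: in training mode it normalizes with the statistics of the current mini-batch and updates its moving mean and variance, while in inference mode it uses those stored moving statistics instead. The learnable parameters $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are regular trainable weights that the optimizer updates during training.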
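The difference between the two modes can be illustrated with a small NumPy class that keeps moving averages of the batch statistics. The class name and structure are hypothetical; only the defaults (momentum=0.99, epsilon=1e-3, moving mean initialized to zeros and moving variance to ones) are meant to mirror Keras, and details such as variance correction may differ.

```python
import numpy as np

class SimpleBatchNorm:
    """Hypothetical per-channel BatchNorm with moving statistics."""

    def __init__(self, channels, momentum=0.99, eps=1e-3):
        self.gamma = np.ones(channels)          # learnable scale
        self.beta = np.zeros(channels)          # learnable shift
        self.moving_mean = np.zeros(channels)   # running estimate of the mean
        self.moving_var = np.ones(channels)     # running estimate of the variance
        self.momentum = momentum
        self.eps = eps

    def __call__(self, Z, training):
        axes = tuple(range(Z.ndim - 1))         # every axis except channels
        if training:
            mean, var = Z.mean(axis=axes), Z.var(axis=axes)
            # Moving statistics are updated only in training mode
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        Z_std = (Z - mean) / np.sqrt(var + self.eps)
        return self.gamma * Z_std + self.beta

rng = np.random.default_rng(2)
bn = SimpleBatchNorm(channels=3)
batch = rng.normal(loc=4.0, size=(32, 5, 5, 3))

out_train = bn(batch, training=True)      # normalized with batch statistics
before = bn.moving_mean.copy()
out_infer = bn(batch, training=False)     # normalized with moving statistics
print(np.allclose(before, bn.moving_mean))    # True: inference does not update them
print(out_train.mean(), out_infer.mean())     # near 0 versus roughly 4: the modes differ
```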

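Why BatchNorm helps optimization

BatchNorm was originally developed to reduce so-called internal covariate shift, which refers to the change in the distribution of a layer's activations caused by updates to the network parameters during training.

To illustrate, consider a fixed batch that passes through the network during epoch 1, and record the activations of each layer for this batch.

After iterating over the entire training dataset and updating the model parameters, we start the second epoch, in which the same fixed batch passes through the network again. We then compare the layer activations from the first and second epochs.

Since the network parameters have changed, the activations have changed as well. This phenomenon is called internal covariate shift, and it was believed to slow down the training of neural networks.

However, later work by S. Santurkar et al. ("How Does Batch Normalization Help Optimization?", NeurIPS 2018) showed that BatchNorm makes the loss surface smoother, which makes optimization more robust when dealing with non-convex problems.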

Implementing the generator and discriminator

Image(filename='images/17_12.png', width=700)


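The generator takes an input vector $\boldsymbol{z}$ and uses a fully connected layer to increase its size to 6,272, then reshapes it into a rank-3 tensor of shape 7×7×128 (spatial dimensions 7×7, 128 channels).

Next, a series of transposed convolutions (tf.keras.layers.Conv2DTranspose()) upsample the feature maps until the spatial dimensions reach 28×28. The number of channels is halved after each transposed convolutional layer, except for the last one.

The last layer uses only a single output filter to produce a grayscale image.

Each transposed convolutional layer is followed by BatchNorm and a Leaky ReLU activation, except for the last one, which uses a tanh activation instead.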
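As a sanity check on this architecture, the parameter counts in the generator summary printed later can be reproduced by hand, assuming the default arguments of make_dcgan_generator (z_size=20, n_filters=128, n_blocks=2, 5×5 kernels, no biases). Each BatchNorm layer contributes four values per feature (γ, β, moving mean, moving variance), of which the two moving statistics are non-trainable.

```python
# Hand-computed parameter counts for make_dcgan_generator() with its defaults
z_size, n_filters, k = 20, 128, 5
hidden = 28 // 2 ** 2                      # 7: spatial size before upsampling

dense = z_size * n_filters * hidden * hidden   # Dense(use_bias=False): 20 * 6272

# Conv2DTranspose kernels are k x k x in x out, with no bias terms
convt = [k * k * c_in * c_out
         for c_in, c_out in [(128, 128), (128, 64), (64, 32), (32, 1)]]

# Features normalized by each BatchNorm layer: the Dense output, then 128, 64, 32 channels
bn_features = [n_filters * hidden * hidden, 128, 64, 32]
bn_total = sum(4 * f for f in bn_features)       # gamma, beta, moving mean/var
non_trainable = sum(2 * f for f in bn_features)  # moving mean/var only

total = dense + sum(convt) + bn_total
print(dense, convt, total, non_trainable)
# 125440 [409600, 204800, 51200, 800] 817824 12992
```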

Image(filename='images/17_13.png', width=700)


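The discriminator receives images of size 28×28×1, which pass through four convolutional layers.

Using bias units in a layer that is followed by BatchNorm is generally not recommended: the bias units are redundant in that case, because BatchNorm already has a shift parameter, $\boldsymbol{\beta}$.

This can be configured by setting use_bias=False.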
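The redundancy is easy to verify numerically: standardizing over the batch subtracts the per-feature mean, which removes any constant bias. The sketch below uses a hypothetical dense layer and a simplified BatchNorm without γ and β.

```python
import numpy as np

def batchnorm(Z, eps=1e-3):
    """Per-feature standardization over the batch axis (gamma=1, beta=0)."""
    return (Z - Z.mean(axis=0)) / np.sqrt(Z.var(axis=0) + eps)

rng = np.random.default_rng(3)
X = rng.normal(size=(64, 10))   # batch of 64 inputs with 10 features
W = rng.normal(size=(10, 5))    # weights of a hypothetical dense layer
b = rng.normal(size=5)          # bias vector that BatchNorm will cancel out

out_without_bias = batchnorm(X @ W)
out_with_bias = batchnorm(X @ W + b)
print(np.allclose(out_without_bias, out_with_bias))   # True
```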

  • Setting up the Google Colab
#! pip install -q tensorflow-gpu==2.0.0-beta1
#from google.colab import drive
#drive.mount('/content/drive/')
import tensorflow as tf
tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
import tensorflow as tf


print(tf.__version__)

gpus = tf.config.list_physical_devices('GPU')
print("GPU Available:", gpus)

if gpus:
    device_name = tf.test.gpu_device_name()
else:
    device_name = '/cpu:0'

print(device_name)
2.1.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
/device:GPU:0
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def make_dcgan_generator(
        z_size=20, 
        output_size=(28, 28, 1),
        n_filters=128, 
        n_blocks=2):
    size_factor = 2**n_blocks
    hidden_size = (
        output_size[0]//size_factor, 
        output_size[1]//size_factor
    )
    
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(z_size,)),
        
        tf.keras.layers.Dense(
            units=n_filters*np.prod(hidden_size), 
            use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Reshape(
            (hidden_size[0], hidden_size[1], n_filters)),
    
        tf.keras.layers.Conv2DTranspose(
            filters=n_filters, kernel_size=(5, 5), strides=(1, 1),
            padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU()
    ])
        
    nf = n_filters
    for i in range(n_blocks):
        nf = nf // 2
        model.add(
            tf.keras.layers.Conv2DTranspose(
                filters=nf, kernel_size=(5, 5), strides=(2, 2),
                padding='same', use_bias=False))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.LeakyReLU())
                
    model.add(
        tf.keras.layers.Conv2DTranspose(
            filters=output_size[2], kernel_size=(5, 5), 
            strides=(1, 1), padding='same', use_bias=False, 
            activation='tanh'))
        
    return model

def make_dcgan_discriminator(
        input_size=(28, 28, 1),
        n_filters=64, 
        n_blocks=2):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=input_size),
        tf.keras.layers.Conv2D(
            filters=n_filters, kernel_size=5, 
            strides=(1, 1), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU()
    ])
    
    nf = n_filters
    for i in range(n_blocks):
        nf = nf*2
        model.add(
            tf.keras.layers.Conv2D(
                filters=nf, kernel_size=(5, 5), 
                strides=(2, 2),padding='same'))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.LeakyReLU())
        model.add(tf.keras.layers.Dropout(0.3))
        
    model.add(tf.keras.layers.Conv2D(
            filters=1, kernel_size=(7, 7), padding='valid'))
    
    model.add(tf.keras.layers.Reshape((1,)))
    
    return model
gen_model = make_dcgan_generator()
gen_model.summary()

disc_model = make_dcgan_discriminator()
disc_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 6272)              125440    
_________________________________________________________________
batch_normalization (BatchNo (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 6272)              0         
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 7, 7, 128)         409600    
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64)        204800    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 32)        51200     
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 32)        128       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 1)         800       
=================================================================
Total params: 817,824
Trainable params: 804,832
Non-trainable params: 12,992
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 64)        1664      
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       204928    
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 256)         819456    
_________________________________________________________________
batch_normalization_6 (Batch (None, 7, 7, 256)         1024      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 7, 7, 256)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 1, 1, 1)           12545     
_________________________________________________________________
reshape_1 (Reshape)          (None, 1)                 0         
=================================================================
Total params: 1,040,385
Trainable params: 1,039,489
Non-trainable params: 896
_________________________________________________________________

Different ways to measure the dissimilarity between two distributions

Let $P(x)$ and $Q(x)$ denote distributions of a random variable $x$. The measures are summarized below:

Image(filename='images/17_14.png', width=700)


T V TV TV度量方法中,上确界(最小上界)函数supremum function代表的是:大于S中所有元素的值的最小值;

换句话说, s u p ( S ) sup(S) sup(S) S S S的最小上届。

  • TV 距离:衡量了两个分布对于所有数据点的最大差异;

  • EM 距离:可以理解为将一个分布转换为另一个分布的所需要做的最小工作量;EM距离中的下确界函数被替换了;

  • KL和JS散度来自于信息论领域。与JS散度相反,KL散度不对称,即 K L ( P ∥ Q ) ≠ K L ( Q ∥ P ) K L(P \| Q) \neq K L(Q \| P) KL(PQ)?=KL(QP)
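The following NumPy sketch computes the four measures for two small discrete distributions on the same support. It assumes the common definition $TV(P,Q)=\sup_A |P(A)-Q(A)|$, which for discrete distributions equals half the L1 distance, and computes the EM distance from the cumulative distribution functions, which is valid only in one dimension. The function names are illustrative.

```python
import numpy as np

def tv_distance(p, q):
    # sup over events |P(A) - Q(A)| equals 0.5 * sum |p - q| for discrete p, q
    return 0.5 * np.abs(p - q).sum()

def kl_divergence(p, q):
    # Sum over the support of p; infinite if q is 0 where p > 0
    mask = p > 0
    if np.any(q[mask] == 0):
        return np.inf
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def js_divergence(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def em_distance_1d(p, q, support):
    # 1D only: EM = integral of |CDF_P - CDF_Q| over the sorted support
    cdf_diff = np.cumsum(p - q)[:-1]
    return np.sum(np.abs(cdf_diff) * np.diff(support))

support = np.array([0.0, 1.0, 2.0])
p = np.array([0.5, 0.5, 0.0])
q = np.array([0.25, 0.5, 0.25])

print("TV:", tv_distance(p, q))                          # 0.25
print("KL(P||Q):", kl_divergence(p, q))                  # 0.5 * log(2), about 0.347
print("KL(Q||P):", kl_divergence(q, p))                  # inf: KL is not symmetric
print("JS:", js_divergence(p, q), js_divergence(q, p))   # equal: JS is symmetric
print("EM:", em_distance_1d(p, q, support))              # 0.5
```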

Some example calculations are shown below:

Image(filename='images/17_15.png', width=800)


Relationship between KL divergence and cross-entropy

The KL divergence measures the relative entropy of a distribution $P$ with respect to a reference distribution $Q$. It is computed as:

$$KL(P \| Q)=-\int P(x) \log (Q(x))\, dx-\left(-\int P(x) \log (P(x))\, dx\right) \tag{6}$$

For discrete distributions, the KL divergence is:

$$KL(P \| Q)=\sum_{i} P\left(x_{i}\right) \log \frac{P\left(x_{i}\right)}{Q\left(x_{i}\right)} \tag{7}$$

which can equivalently be written as:

$$KL(P \| Q)=-\sum_{i} P\left(x_{i}\right) \log \left(Q\left(x_{i}\right)\right)-\left(-\sum_{i} P\left(x_{i}\right) \log \left(P\left(x_{i}\right)\right)\right) \tag{8}$$

Therefore, the KL divergence can be viewed as the cross-entropy between $P$ and $Q$ minus the entropy of $P$:

$$KL(P \| Q)=H(P, Q)-H(P) \tag{9}$$
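A small numeric check of Eq. 9 for two hypothetical discrete distributions (natural logarithms):

```python
import numpy as np

p = np.array([0.1, 0.4, 0.5])   # hypothetical distribution P
q = np.array([0.3, 0.3, 0.4])   # hypothetical reference distribution Q

kl_direct = np.sum(p * np.log(p / q))      # Eq. 7
cross_entropy = -np.sum(p * np.log(q))     # H(P, Q)
entropy = -np.sum(p * np.log(p))           # H(P)

print(kl_direct, cross_entropy - entropy)  # the two values agree (Eq. 9)
```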

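Related research has shown that training GAN models with the JS divergence can be problematic, and suggests using the EM distance instead.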

Advantages of using the EM distance

Following an example from the Wasserstein GAN paper: suppose we have two distributions $P$ and $Q$ supported on two parallel lines, one being the y-axis and the other the line $x=\theta$, where $\theta>0$.

Computing the KL, TV, and JS measures then gives:

$$KL(P \| Q)=+\infty, \quad TV(P, Q)=1, \quad \text{and} \quad JS(P, Q)=\log 2 \tag{10}$$

All of these measures are independent of $\theta$ (as long as $\theta > 0$), so they do not reflect how far apart the two distributions are. In contrast, the EM distance is:

$$EM(P, Q)=|\theta| \tag{11}$$
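The parallel-lines example can be reduced to one dimension by projecting onto the x-axis: both lines carry the same distribution along y, so an optimal transport plan only moves mass horizontally by θ. In this simplified sketch (an assumption, not code from the paper), P and Q become point masses at 0 and θ, and only the EM distance changes as θ varies.

```python
import numpy as np

def kl(a, b):
    mask = a > 0
    if np.any(b[mask] == 0):
        return np.inf
    return np.sum(a[mask] * np.log(a[mask] / b[mask]))

def measures(theta):
    """TV, KL, JS and EM for point masses at x=0 (P) and x=theta (Q)."""
    support = np.array([0.0, theta])   # x-coordinates of the two lines
    p = np.array([1.0, 0.0])           # P: all mass on the line x = 0
    q = np.array([0.0, 1.0])           # Q: all mass on the line x = theta
    m = 0.5 * (p + q)
    tv = 0.5 * np.abs(p - q).sum()
    js = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    em = np.sum(np.abs(np.cumsum(p - q)[:-1]) * np.diff(support))  # 1D EM via CDFs
    return tv, kl(p, q), js, em

for theta in [0.5, 1.0, 2.0]:
    tv, kl_pq, js, em = measures(theta)
    # TV=1.0, KL=inf and JS=log 2 (about 0.6931) for every theta; EM equals theta
    print(f"theta={theta}: TV={tv}, KL={kl_pq}, JS={js:.4f}, EM={em}")
```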

Using the EM distance as a metric in GANs

Let $P_r$ be the distribution of real examples and $P_g$ the distribution of generated examples; they take the roles of $P$ and $Q$ in the EM distance formula.

Computing the EM distance is itself an optimization problem and is computationally demanding. Fortunately, it can be simplified using the Kantorovich-Rubinstein duality:

$$W\left(P_{r}, P_{g}\right)=\sup _{\|f\|_{L} \leq 1} \mathbb{E}_{u \sim P_{r}}[f(u)]-\mathbb{E}_{v \sim P_{g}}[f(v)] \tag{12}$$

Here the supremum is taken over all 1-Lipschitz continuous functions, denoted $\|f\|_{L} \leq 1$.

Lipschitz continuity

A 1-Lipschitz continuous function $f$ must satisfy:

$$\left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq\left|x_{1}-x_{2}\right| \tag{13}$$

More generally, a real function $f: \mathbb{R} \rightarrow \mathbb{R}$ is called K-Lipschitz continuous if it satisfies:

$$\left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq K\left|x_{1}-x_{2}\right| \tag{14}$$
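To make the definition concrete, the sketch below estimates a lower bound on the Lipschitz constant of a function from samples by taking the steepest slope between any two sample points. The helper is hypothetical and only gives an empirical estimate on a grid, not a proof.

```python
import numpy as np

def empirical_lipschitz(f, xs):
    """Largest slope |f(x1) - f(x2)| / |x1 - x2| over all pairs of sample points."""
    x1, x2 = np.meshgrid(xs, xs)
    mask = x1 != x2
    slopes = np.abs(f(x1[mask]) - f(x2[mask])) / np.abs(x1[mask] - x2[mask])
    return slopes.max()

xs = np.linspace(-3.0, 3.0, 401)
print(empirical_lipschitz(np.sin, xs))            # just below 1: sin is 1-Lipschitz
print(empirical_lipschitz(lambda x: 3 * x, xs))   # 3: a K-Lipschitz function with K = 3
```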

To train a GAN using the Wasserstein distance, the losses for the discriminator and generator are defined as follows:

Discriminator, real examples:

$$L_{\text{real}}^{D}=-\frac{1}{N} \sum_{i} D\left(\boldsymbol{x}_{i}\right) \tag{15}$$

Discriminator, generated (fake) examples:

$$L_{\text{fake}}^{D}=\frac{1}{N} \sum_{i} D\left(G\left(\boldsymbol{z}_{i}\right)\right) \tag{16}$$

Generator:

$$L^{G}=-\frac{1}{N} \sum_{i} D\left(G\left(\boldsymbol{z}_{i}\right)\right) \tag{17}$$
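Given critic scores for a batch, Eqs. 15-17 are just batch means with different signs. The sketch below uses made-up scores to show that $L_{\text{real}}^D + L_{\text{fake}}^D$ is the negative of the critic's estimate of $\mathbb{E}[D(\boldsymbol{x})]-\mathbb{E}[D(G(\boldsymbol{z}))]$, so minimizing it maximizes the estimated Wasserstein distance in Eq. 12.

```python
import numpy as np

def wgan_losses(d_real, d_fake):
    """WGAN losses from critic scores on real and generated examples."""
    d_loss_real = -np.mean(d_real)   # Eq. 15
    d_loss_fake = np.mean(d_fake)    # Eq. 16
    g_loss = -np.mean(d_fake)        # Eq. 17
    return d_loss_real, d_loss_fake, g_loss

d_real = np.array([2.0, 1.5, 2.5])    # hypothetical scores D(x_i)
d_fake = np.array([-1.0, 0.0, -0.5])  # hypothetical scores D(G(z_i))

d_loss_real, d_loss_fake, g_loss = wgan_losses(d_real, d_fake)
print(d_loss_real + d_loss_fake)  # -2.5 = -(mean(d_real) - mean(d_fake))
print(g_loss)                     # 0.5
```

These correspond to d_loss_real, d_loss_fake, and g_loss in the training loop below.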

Gradient penalty (GP)

GAN with gradient penalty: WGAN-GP

Related research has shown that weight clipping in WGAN can lead to exploding or vanishing gradients.

Therefore, instead of weight clipping, we use a gradient penalty (GP); the corresponding network is called WGAN with gradient penalty (WGAN-GP).

The procedure for adding the GP in each iteration is summarized as follows:

  • For each pair of real and fake examples in a given batch, choose a random number $\alpha^{[i]}$ sampled from a uniform distribution, i.e., $\alpha^{[i]} \sim U(0,1)$;

  • Compute an interpolation between the real and fake examples, $\breve{\boldsymbol{x}}^{[i]}=\alpha^{[i]} \boldsymbol{x}^{[i]}+\left(1-\alpha^{[i]}\right) \widetilde{\boldsymbol{x}}^{[i]}$, which yields a batch of interpolated examples;

  • Compute the discriminator output for all interpolated examples, $D\left(\breve{\boldsymbol{x}}^{[i]}\right)$;

  • Compute the gradients of the discriminator output with respect to each interpolated example, $\nabla_{\breve{\boldsymbol{x}}^{[i]}} D\left(\breve{\boldsymbol{x}}^{[i]}\right)$;

  • Compute the gradient penalty:

$$L_{gp}^{D}=\frac{1}{N} \sum_{i}\left(\left\|\nabla_{\breve{\boldsymbol{x}}^{[i]}} D\left(\breve{\boldsymbol{x}}^{[i]}\right)\right\|_{2}-1\right)^{2} \tag{18}$$

The total loss of the discriminator is therefore:

$$L_{\text{total}}^{D}=L_{\text{real}}^{D}+L_{\text{fake}}^{D}+\lambda L_{gp}^{D} \tag{19}$$

where $\lambda$ is a hyperparameter.
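The GP steps and Eq. 19 can be sketched in NumPy. To stay self-contained, the sketch replaces the discriminator network with a hypothetical toy critic $D(\boldsymbol{x})=\frac{1}{2}\sum \boldsymbol{w}\odot\boldsymbol{x}^2$ whose input gradient $\boldsymbol{w}\odot\boldsymbol{x}$ is known in closed form; in the TensorFlow training loop the gradient comes from tf.GradientTape instead.

```python
import numpy as np

rng = np.random.default_rng(4)
n, height, width, channels = 8, 28, 28, 1
x_real = rng.uniform(-1, 1, size=(n, height, width, channels))  # stand-in real batch
x_fake = rng.uniform(-1, 1, size=(n, height, width, channels))  # stand-in generated batch
w = rng.normal(scale=0.05, size=(height, width, channels))      # toy critic weights

def critic(x):
    # Hypothetical toy critic: one score per example
    return 0.5 * np.sum(w * x ** 2, axis=(1, 2, 3))

def critic_grad(x):
    # Closed-form gradient of the toy critic with respect to its input
    return w * x

lambda_gp = 10.0

# Step 1: one alpha ~ U(0, 1) per example, broadcast over height, width, channels
alpha = rng.uniform(0.0, 1.0, size=(n, 1, 1, 1))
# Step 2: interpolate between real and fake examples
x_interp = alpha * x_real + (1.0 - alpha) * x_fake
# Steps 3-4: critic output and its gradient at the interpolated examples
d_interp = critic(x_interp)
grads = critic_grad(x_interp)
# Step 5: per-example L2 norm of the gradient, then Eq. 18
grad_norms = np.sqrt(np.sum(grads ** 2, axis=(1, 2, 3)))
gp = np.mean((grad_norms - 1.0) ** 2)

# Eq. 19: total discriminator loss
d_loss_real = -np.mean(critic(x_real))
d_loss_fake = np.mean(critic(x_fake))
d_loss_total = d_loss_real + d_loss_fake + lambda_gp * gp
print(grad_norms.shape, gp, d_loss_total)
```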

Training the WGAN-GP model

mnist_bldr = tfds.builder('mnist')
mnist_bldr.download_and_prepare()
mnist = mnist_bldr.as_dataset(shuffle_files=False)

def preprocess(ex, mode='uniform'):
    image = ex['image']
    image = tf.image.convert_image_dtype(image, tf.float32)

    image = image*2 - 1.0
    if mode == 'uniform':
        input_z = tf.random.uniform(
            shape=(z_size,), minval=-1.0, maxval=1.0)
    elif mode == 'normal':
        input_z = tf.random.normal(shape=(z_size,))
    return input_z, image
num_epochs = 100
batch_size = 128
image_size = (28, 28)
z_size = 20
mode_z = 'uniform'
lambda_gp = 10.0

tf.random.set_seed(1)
np.random.seed(1)

## Set-up the dataset
mnist_trainset = mnist['train']
mnist_trainset = mnist_trainset.map(preprocess)

mnist_trainset = mnist_trainset.shuffle(10000)
mnist_trainset = mnist_trainset.batch(
    batch_size, drop_remainder=True)

## Set-up the model
with tf.device(device_name):
    gen_model = make_dcgan_generator()
    gen_model.build(input_shape=(None, z_size))
    gen_model.summary()

    disc_model = make_dcgan_discriminator()
    disc_model.build(input_shape=(None, *image_size, 1))
    disc_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 6272)              125440    
_________________________________________________________________
batch_normalization_7 (Batch (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 6272)              0         
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 7, 7, 128)         409600    
_________________________________________________________________
batch_normalization_8 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 14, 14, 64)        204800    
_________________________________________________________________
batch_normalization_9 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 28, 28, 32)        51200     
_________________________________________________________________
batch_normalization_10 (Batc (None, 28, 28, 32)        128       
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 28, 28, 1)         800       
=================================================================
Total params: 817,824
Trainable params: 804,832
Non-trainable params: 12,992
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 28, 28, 64)        1664      
_________________________________________________________________
batch_normalization_11 (Batc (None, 28, 28, 64)        256       
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 128)       204928    
_________________________________________________________________
batch_normalization_12 (Batc (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 14, 14, 128)       0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 7, 7, 256)         819456    
_________________________________________________________________
batch_normalization_13 (Batc (None, 7, 7, 256)         1024      
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU)   (None, 7, 7, 256)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 1, 1, 1)           12545     
_________________________________________________________________
reshape_3 (Reshape)          (None, 1)                 0         
=================================================================
Total params: 1,040,385
Trainable params: 1,039,489
Non-trainable params: 896
_________________________________________________________________
import time


## optimizers:
g_optimizer = tf.keras.optimizers.Adam(0.0002)
d_optimizer = tf.keras.optimizers.Adam(0.0002)

if mode_z == 'uniform':
    fixed_z = tf.random.uniform(
        shape=(batch_size, z_size),
        minval=-1, maxval=1)
elif mode_z == 'normal':
    fixed_z = tf.random.normal(
        shape=(batch_size, z_size))

def create_samples(g_model, input_z):
    g_output = g_model(input_z, training=False)
    images = tf.reshape(g_output, (batch_size, *image_size))    
    return (images+1)/2.0

all_losses = []
epoch_samples = []

start_time = time.time()

for epoch in range(1, num_epochs+1):
    epoch_losses = []
    for i,(input_z,input_real) in enumerate(mnist_trainset):
        
        ## Compute discriminator's loss and gradients:
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            g_output = gen_model(input_z, training=True)
            
            d_critics_real = disc_model(input_real, training=True)
            d_critics_fake = disc_model(g_output, training=True)

            ## Compute generator's loss:
            g_loss = -tf.math.reduce_mean(d_critics_fake)

            ## Compute discriminator's losses
            d_loss_real = -tf.math.reduce_mean(d_critics_real)
            d_loss_fake =  tf.math.reduce_mean(d_critics_fake)
            d_loss = d_loss_real + d_loss_fake

            ## Gradient penalty:
            with tf.GradientTape() as gp_tape:
                alpha = tf.random.uniform(
                    shape=[d_critics_real.shape[0], 1, 1, 1], 
                    minval=0.0, maxval=1.0)
                interpolated = (
                    alpha*input_real + (1-alpha)*g_output)
                gp_tape.watch(interpolated)
                d_critics_intp = disc_model(interpolated)
            
            grads_intp = gp_tape.gradient(
                d_critics_intp, [interpolated,])[0]
            grads_intp_l2 = tf.sqrt(
                tf.reduce_sum(tf.square(grads_intp), axis=[1, 2, 3]))
            grad_penalty = tf.reduce_mean(tf.square(grads_intp_l2 - 1.0))
        
            d_loss = d_loss + lambda_gp*grad_penalty
        
        ## Optimization: Compute the gradients and apply them
        d_grads = d_tape.gradient(d_loss, disc_model.trainable_variables)
        d_optimizer.apply_gradients(
            grads_and_vars=zip(d_grads, disc_model.trainable_variables))
        
        g_grads = g_tape.gradient(g_loss, gen_model.trainable_variables)
        g_optimizer.apply_gradients(
            grads_and_vars=zip(g_grads, gen_model.trainable_variables))

        epoch_losses.append(
            (g_loss.numpy(), d_loss.numpy(), 
             d_loss_real.numpy(), d_loss_fake.numpy()))
                    
    all_losses.append(epoch_losses)
    
    print('Epoch {:-3d} | ET {:.2f} min | Avg Losses >>'
          ' G/D {:6.2f}/{:6.2f} [D-Real: {:6.2f} D-Fake: {:6.2f}]'
          .format(epoch, (time.time() - start_time)/60, 
                  *list(np.mean(all_losses[-1], axis=0)))
    )
    
    epoch_samples.append(
        create_samples(gen_model, fixed_z).numpy()
    )
Epoch   1 | ET 1.58 min | Avg Losses >> G/D 186.47/-305.89 [D-Real: -204.71 D-Fake: -186.47]
Epoch   2 | ET 3.17 min | Avg Losses >> G/D 125.83/-23.81 [D-Real: -88.27 D-Fake: -125.83]
Epoch   3 | ET 4.77 min | Avg Losses >> G/D  81.43/ -2.47 [D-Real:  29.42 D-Fake: -81.43]
Epoch   4 | ET 6.37 min | Avg Losses >> G/D  43.89/ -6.47 [D-Real: -14.41 D-Fake: -43.89]
Epoch   5 | ET 7.97 min | Avg Losses >> G/D  37.12/ -5.75 [D-Real:   4.77 D-Fake: -37.12]
Epoch   6 | ET 9.56 min | Avg Losses >> G/D  37.59/-10.89 [D-Real:  11.12 D-Fake: -37.59]
Epoch   7 | ET 11.16 min | Avg Losses >> G/D  48.30/-13.90 [D-Real:  32.22 D-Fake: -48.30]
Epoch   8 | ET 12.76 min | Avg Losses >> G/D  43.40/-20.73 [D-Real:  19.27 D-Fake: -43.40]
Epoch   9 | ET 14.36 min | Avg Losses >> G/D  51.94/-32.62 [D-Real:  10.19 D-Fake: -51.94]
Epoch  10 | ET 15.97 min | Avg Losses >> G/D  54.99/-34.56 [D-Real:  10.05 D-Fake: -54.99]
Epoch  11 | ET 17.57 min | Avg Losses >> G/D  78.72/-43.58 [D-Real:  21.09 D-Fake: -78.72]
Epoch  12 | ET 19.16 min | Avg Losses >> G/D  83.79/-43.11 [D-Real:  26.02 D-Fake: -83.79]
Epoch  13 | ET 20.76 min | Avg Losses >> G/D  84.35/-33.70 [D-Real:  29.12 D-Fake: -84.35]
Epoch  14 | ET 22.36 min | Avg Losses >> G/D  78.19/-30.97 [D-Real:  34.40 D-Fake: -78.19]
Epoch  15 | ET 23.96 min | Avg Losses >> G/D  75.11/-44.77 [D-Real:  18.49 D-Fake: -75.11]
Epoch  16 | ET 25.56 min | Avg Losses >> G/D  83.66/-44.93 [D-Real:  25.06 D-Fake: -83.66]
Epoch  17 | ET 27.16 min | Avg Losses >> G/D  97.82/-41.78 [D-Real:  42.69 D-Fake: -97.82]
Epoch  18 | ET 28.76 min | Avg Losses >> G/D 102.77/-41.13 [D-Real:  44.20 D-Fake: -102.77]
Epoch  19 | ET 30.36 min | Avg Losses >> G/D  98.06/-45.39 [D-Real:  36.53 D-Fake: -98.06]
Epoch  20 | ET 31.96 min | Avg Losses >> G/D  92.78/-43.22 [D-Real:  39.53 D-Fake: -92.78]
Epoch  21 | ET 33.56 min | Avg Losses >> G/D 119.17/-45.35 [D-Real:  63.00 D-Fake: -119.17]
Epoch  22 | ET 35.15 min | Avg Losses >> G/D 167.63/-45.22 [D-Real: 109.72 D-Fake: -167.63]
Epoch  23 | ET 36.75 min | Avg Losses >> G/D 143.33/-43.10 [D-Real:  78.66 D-Fake: -143.33]
Epoch  24 | ET 38.35 min | Avg Losses >> G/D 165.30/-45.59 [D-Real: 108.96 D-Fake: -165.30]
Epoch  25 | ET 39.95 min | Avg Losses >> G/D 185.48/-49.65 [D-Real: 120.26 D-Fake: -185.48]
Epoch  26 | ET 41.54 min | Avg Losses >> G/D 177.84/-46.74 [D-Real: 111.94 D-Fake: -177.84]
Epoch  27 | ET 43.14 min | Avg Losses >> G/D 235.96/-57.86 [D-Real: 165.91 D-Fake: -235.96]
Epoch  28 | ET 44.74 min | Avg Losses >> G/D 246.81/-30.65 [D-Real: 198.56 D-Fake: -246.81]
Epoch  29 | ET 46.34 min | Avg Losses >> G/D 296.19/-36.56 [D-Real: 256.34 D-Fake: -296.19]
Epoch  30 | ET 47.94 min | Avg Losses >> G/D 317.43/-51.34 [D-Real: 260.19 D-Fake: -317.43]
Epoch  31 | ET 49.54 min | Avg Losses >> G/D 343.99/-47.43 [D-Real: 285.17 D-Fake: -343.99]
Epoch  32 | ET 51.14 min | Avg Losses >> G/D 367.00/-46.06 [D-Real: 303.26 D-Fake: -367.00]
Epoch  33 | ET 52.74 min | Avg Losses >> G/D 359.17/ 15.57 [D-Real: 311.80 D-Fake: -359.17]
Epoch  34 | ET 54.35 min | Avg Losses >> G/D 306.75/-28.72 [D-Real: 252.32 D-Fake: -306.75]
Epoch  35 | ET 55.94 min | Avg Losses >> G/D 338.60/-47.47 [D-Real: 282.77 D-Fake: -338.60]
Epoch  36 | ET 57.54 min | Avg Losses >> G/D 348.70/-51.16 [D-Real: 285.71 D-Fake: -348.70]
Epoch  37 | ET 59.14 min | Avg Losses >> G/D 339.03/-42.02 [D-Real: 291.79 D-Fake: -339.03]
Epoch  38 | ET 60.73 min | Avg Losses >> G/D 384.91/-48.19 [D-Real: 321.08 D-Fake: -384.91]
Epoch  39 | ET 62.34 min | Avg Losses >> G/D 377.89/-46.81 [D-Real: 322.20 D-Fake: -377.89]
Epoch  40 | ET 63.94 min | Avg Losses >> G/D 356.70/-41.24 [D-Real: 307.53 D-Fake: -356.70]
Epoch  41 | ET 65.54 min | Avg Losses >> G/D 352.23/-36.78 [D-Real: 312.28 D-Fake: -352.23]
Epoch  42 | ET 67.14 min | Avg Losses >> G/D 512.82/-54.96 [D-Real: 452.57 D-Fake: -512.82]
Epoch  43 | ET 68.74 min | Avg Losses >> G/D 525.44/-58.37 [D-Real: 451.38 D-Fake: -525.44]
Epoch  44 | ET 70.34 min | Avg Losses >> G/D 537.93/-46.99 [D-Real: 471.04 D-Fake: -537.93]
Epoch  45 | ET 71.94 min | Avg Losses >> G/D 586.10/-53.15 [D-Real: 523.73 D-Fake: -586.10]
Epoch  46 | ET 73.53 min | Avg Losses >> G/D 631.31/-61.20 [D-Real: 556.68 D-Fake: -631.31]
Epoch  47 | ET 75.14 min | Avg Losses >> G/D 644.35/-66.45 [D-Real: 571.78 D-Fake: -644.35]
Epoch  48 | ET 76.74 min | Avg Losses >> G/D 766.90/-68.34 [D-Real: 683.32 D-Fake: -766.90]
Epoch  49 | ET 78.33 min | Avg Losses >> G/D 715.37/-11.67 [D-Real: 654.83 D-Fake: -715.37]
Epoch  50 | ET 79.93 min | Avg Losses >> G/D 707.44/-65.47 [D-Real: 631.52 D-Fake: -707.44]
Epoch  51 | ET 81.53 min | Avg Losses >> G/D 682.43/-77.68 [D-Real: 600.63 D-Fake: -682.43]
Epoch  52 | ET 83.12 min | Avg Losses >> G/D 883.31/-74.31 [D-Real: 797.64 D-Fake: -883.31]
Epoch  53 | ET 84.72 min | Avg Losses >> G/D 828.10/ -4.36 [D-Real: 771.68 D-Fake: -828.10]
Epoch  54 | ET 86.32 min | Avg Losses >> G/D 916.28/-18.10 [D-Real: 871.71 D-Fake: -916.28]
Epoch  55 | ET 87.92 min | Avg Losses >> G/D 890.63/-52.51 [D-Real: 832.85 D-Fake: -890.63]
Epoch  56 | ET 89.51 min | Avg Losses >> G/D 706.38/-60.17 [D-Real: 642.28 D-Fake: -706.38]
Epoch  57 | ET 91.11 min | Avg Losses >> G/D 841.65/-93.87 [D-Real: 740.63 D-Fake: -841.65]
Epoch  58 | ET 92.71 min | Avg Losses >> G/D 1030.48/-117.84 [D-Real: 902.95 D-Fake: -1030.48]
Epoch  59 | ET 94.31 min | Avg Losses >> G/D 965.95/-89.45 [D-Real: 867.79 D-Fake: -965.95]
Epoch  60 | ET 95.91 min | Avg Losses >> G/D 1197.10/-105.03 [D-Real: 1055.86 D-Fake: -1197.10]
Epoch  61 | ET 97.51 min | Avg Losses >> G/D 1094.58/-125.98 [D-Real: 946.71 D-Fake: -1094.58]
Epoch  62 | ET 99.11 min | Avg Losses >> G/D 1158.20/-130.87 [D-Real: 1009.98 D-Fake: -1158.20]
Epoch  63 | ET 100.70 min | Avg Losses >> G/D 1015.38/-102.59 [D-Real: 907.90 D-Fake: -1015.38]
Epoch  64 | ET 102.30 min | Avg Losses >> G/D 1452.54/-183.59 [D-Real: 1250.75 D-Fake: -1452.54]
Epoch  65 | ET 103.90 min | Avg Losses >> G/D 1532.75/-176.89 [D-Real: 1327.26 D-Fake: -1532.75]
Epoch  66 | ET 105.50 min | Avg Losses >> G/D 1372.27/-184.02 [D-Real: 1173.63 D-Fake: -1372.27]
Epoch  67 | ET 107.10 min | Avg Losses >> G/D 1476.47/-151.54 [D-Real: 1286.28 D-Fake: -1476.47]
Epoch  68 | ET 108.70 min | Avg Losses >> G/D 1337.81/-165.44 [D-Real: 1155.44 D-Fake: -1337.81]
Epoch  69 | ET 110.30 min | Avg Losses >> G/D 1868.09/-265.95 [D-Real: 1564.99 D-Fake: -1868.09]
Epoch  70 | ET 111.90 min | Avg Losses >> G/D 1815.22/155.57 [D-Real: 1589.90 D-Fake: -1815.22]
Epoch  71 | ET 113.50 min | Avg Losses >> G/D 2016.40/-162.41 [D-Real: 1828.58 D-Fake: -2016.40]
Epoch  72 | ET 115.10 min | Avg Losses >> G/D 2118.12/-280.26 [D-Real: 1827.30 D-Fake: -2118.12]
Epoch  73 | ET 116.69 min | Avg Losses >> G/D 2143.48/-318.27 [D-Real: 1818.62 D-Fake: -2143.48]
Epoch  74 | ET 118.29 min | Avg Losses >> G/D 2188.90/-349.97 [D-Real: 1830.72 D-Fake: -2188.90]
Epoch  75 | ET 119.90 min | Avg Losses >> G/D 2362.95/-404.66 [D-Real: 1941.11 D-Fake: -2362.95]
Epoch  76 | ET 121.50 min | Avg Losses >> G/D 2443.07/-427.19 [D-Real: 2004.03 D-Fake: -2443.07]
Epoch  77 | ET 123.10 min | Avg Losses >> G/D 2404.13/-391.35 [D-Real: 1998.95 D-Fake: -2404.13]
Epoch  78 | ET 124.70 min | Avg Losses >> G/D 2562.10/-411.77 [D-Real: 2129.07 D-Fake: -2562.10]
Epoch  79 | ET 126.30 min | Avg Losses >> G/D 2826.90/-540.18 [D-Real: 2257.65 D-Fake: -2826.90]
Epoch  80 | ET 127.90 min | Avg Losses >> G/D 2765.97/-489.78 [D-Real: 2263.39 D-Fake: -2765.97]
Epoch  81 | ET 129.50 min | Avg Losses >> G/D 2775.33/-561.67 [D-Real: 2196.64 D-Fake: -2775.33]
Epoch  82 | ET 131.10 min | Avg Losses >> G/D 3329.70/-599.79 [D-Real: 2616.69 D-Fake: -3329.70]
Epoch  83 | ET 132.70 min | Avg Losses >> G/D 3475.59/-675.97 [D-Real: 2752.62 D-Fake: -3475.59]
Epoch  84 | ET 134.29 min | Avg Losses >> G/D 3634.00/-782.03 [D-Real: 2835.74 D-Fake: -3634.00]
Epoch  85 | ET 135.90 min | Avg Losses >> G/D 3804.03/-801.69 [D-Real: 2961.38 D-Fake: -3804.03]
Epoch  86 | ET 137.50 min | Avg Losses >> G/D 3938.59/-868.39 [D-Real: 3045.14 D-Fake: -3938.59]
Epoch  87 | ET 139.10 min | Avg Losses >> G/D 3951.56/-853.02 [D-Real: 3086.04 D-Fake: -3951.56]
Epoch  88 | ET 140.70 min | Avg Losses >> G/D 3248.74/-723.00 [D-Real: 2495.87 D-Fake: -3248.74]
Epoch  89 | ET 142.29 min | Avg Losses >> G/D 4644.86/-1046.76 [D-Real: 3557.25 D-Fake: -4644.86]
Epoch  90 | ET 143.89 min | Avg Losses >> G/D 4885.19/-1035.63 [D-Real: 3767.10 D-Fake: -4885.19]
Epoch  91 | ET 145.49 min | Avg Losses >> G/D 3802.31/-880.71 [D-Real: 2886.02 D-Fake: -3802.31]
Epoch  92 | ET 147.09 min | Avg Losses >> G/D 4992.27/-1107.25 [D-Real: 3849.27 D-Fake: -4992.27]
Epoch  93 | ET 148.70 min | Avg Losses >> G/D 5438.34/-1285.65 [D-Real: 4107.10 D-Fake: -5438.34]
Epoch  94 | ET 150.30 min | Avg Losses >> G/D 5705.12/-1254.32 [D-Real: 4341.58 D-Fake: -5705.12]
Epoch  95 | ET 151.90 min | Avg Losses >> G/D 6116.86/-1472.97 [D-Real: 4579.57 D-Fake: -6116.86]
Epoch  96 | ET 153.50 min | Avg Losses >> G/D 6402.22/-1595.38 [D-Real: 4765.93 D-Fake: -6402.22]
Epoch  97 | ET 155.10 min | Avg Losses >> G/D 6487.41/-1598.25 [D-Real: 4807.29 D-Fake: -6487.41]
Epoch  98 | ET 156.70 min | Avg Losses >> G/D 5330.31/-1202.23 [D-Real: 4032.31 D-Fake: -5330.31]
Epoch  99 | ET 158.30 min | Avg Losses >> G/D 6946.26/-1631.31 [D-Real: 5270.38 D-Fake: -6946.26]
Epoch 100 | ET 159.90 min | Avg Losses >> G/D 7301.76/-1722.93 [D-Real: 5541.48 D-Fake: -7301.76]
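In every epoch of the log above, the generator loss is exactly the negative of D-Fake (for example, epoch 51: G = 682.43, D-Fake = -682.43). Also, the logged D loss is a little larger than D-Real + D-Fake (epoch 51: 600.63 - 682.43 = -81.80 vs. -77.68), which fits a non-negative gradient-penalty term. The sketch below reproduces that bookkeeping. It assumes the usual WGAN-GP formulation: critic loss = -E[D(x_real)] + E[D(x_fake)] + λ·GP, and generator loss = -E[D(x_fake)]. It also assumes that D-Real and D-Fake are the first two critic terms. The toy quadratic critic and λ = 10 (the WGAN-GP paper's default) are hypothetical stand-ins, not the article's model or settings.

```python
import math
import random

# Hypothetical toy critic, NOT the article's model: D(x) = sum_i w_i * x_i**2.
# Its gradient w.r.t. the input is known in closed form: dD/dx_i = 2 * w_i * x_i.
W = [0.5, -1.0, 0.25]


def critic(x):
    return sum(w * xi ** 2 for w, xi in zip(W, x))


def critic_grad(x):
    return [2.0 * w * xi for w, xi in zip(W, x)]


def mean(values):
    return sum(values) / len(values)


def wgan_gp_losses(real_batch, fake_batch, lambda_gp=10.0, rng=None):
    """Return (g_loss, d_loss, d_loss_real, d_loss_fake) for one batch."""
    rng = rng or random.Random()
    critic_real = [critic(x) for x in real_batch]
    critic_fake = [critic(x) for x in fake_batch]

    g_loss = -mean(critic_fake)        # generator: raise the critic's score on fakes
    d_loss_real = -mean(critic_real)   # critic: raise the score on real samples
    d_loss_fake = mean(critic_fake)    # critic: lower the score on fake samples

    # Gradient penalty: push ||grad D(x_hat)||_2 toward 1 on random interpolates
    penalties = []
    for x_r, x_f in zip(real_batch, fake_batch):
        alpha = rng.random()
        x_hat = [alpha * a + (1.0 - alpha) * b for a, b in zip(x_r, x_f)]
        grad_norm = math.sqrt(sum(gi ** 2 for gi in critic_grad(x_hat)))
        penalties.append((grad_norm - 1.0) ** 2)

    d_loss = d_loss_real + d_loss_fake + lambda_gp * mean(penalties)
    return g_loss, d_loss, d_loss_real, d_loss_fake


rng = random.Random(0)
real = [[rng.gauss(1.0, 1.0) for _ in range(3)] for _ in range(8)]
fake = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(8)]
g, d, d_real, d_fake = wgan_gp_losses(real, fake, rng=rng)
print(f"G/D {g:.2f}/{d:.2f} [D-Real: {d_real:.2f} D-Fake: {d_fake:.2f}]")
```

Because the generator loss and D-Fake are the same batch mean with opposite signs, G = -D-Fake holds for any critic. The penalty only ever adds to the critic loss.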
#import pickle
#pickle.dump({'all_losses':all_losses, 
#             'samples':epoch_samples}, 
#            open('/content/drive/My Drive/Colab Notebooks/PyML-3rd-edition/ch17-WDCGAN-learning.pkl', 'wb'))

#gen_model.save('/content/drive/My Drive/Colab Notebooks/PyML-3rd-edition/ch17-WDCGAN-gan_gen.h5')
#disc_model.save('/content/drive/My Drive/Colab Notebooks/PyML-3rd-edition/ch17-WDCGAN-gan_disc.h5')
import itertools

import matplotlib.pyplot as plt


fig = plt.figure(figsize=(8, 6))

## Plotting the losses
ax = fig.add_subplot(1, 1, 1)
# Flatten the per-epoch lists into one sequence of per-iteration losses
g_losses = [item[0] for item in itertools.chain(*all_losses)]
d_losses = [item[1] for item in itertools.chain(*all_losses)]
plt.plot(g_losses, label='Generator loss', alpha=0.95)
plt.plot(d_losses, label='Discriminator loss', alpha=0.95)
plt.legend(fontsize=20)
ax.set_xlabel('Iteration', size=15)
ax.set_ylabel('Loss', size=15)

## Second x-axis, drawn below the first one, that labels selected epochs
epoch2iter = lambda e: e*len(all_losses[-1])  # epoch number -> iteration index
epoch_ticks = [1, 20, 40, 60, 80, 100]
newpos = [epoch2iter(e) for e in epoch_ticks]
ax2 = ax.twiny()
ax2.set_xticks(newpos)
ax2.set_xticklabels(epoch_ticks)
ax2.xaxis.set_ticks_position('bottom')
ax2.xaxis.set_label_position('bottom')
ax2.spines['bottom'].set_position(('outward', 60))
ax2.set_xlabel('Epoch', size=15)
ax2.set_xlim(ax.get_xlim())
ax.tick_params(axis='both', which='major', labelsize=15)
ax2.tick_params(axis='both', which='major', labelsize=15)

#plt.savefig('images/ch17-wdcgan-learning-curve.pdf')
plt.show()
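The plotting code relies on the nested layout of `all_losses`. The sketch below shows how `itertools.chain` flattens that layout and how `epoch2iter` places epoch ticks. The toy `all_losses` here is hypothetical. It assumes one inner list per epoch and one tuple per iteration. The real tuples may hold more entries, but only the first two are plotted.

```python
import itertools

# Toy stand-in for `all_losses` (assumed layout): one inner list per epoch,
# one (g_loss, d_loss) tuple per training iteration within that epoch.
all_losses = [
    [(1.0, -0.5), (1.2, -0.4), (0.9, -0.6)],  # epoch 1
    [(2.0, -1.0), (2.1, -1.1), (1.9, -0.9)],  # epoch 2
]

# chain(*all_losses) walks through every epoch's tuples in order,
# turning the nested structure into one per-iteration sequence
g_losses = [item[0] for item in itertools.chain(*all_losses)]
d_losses = [item[1] for item in itertools.chain(*all_losses)]


def epoch2iter(e, iters_per_epoch=len(all_losses[-1])):
    """Iteration index at which epoch `e` ends (assumes equal-sized epochs)."""
    return e * iters_per_epoch


print(g_losses)                         # [1.0, 1.2, 0.9, 2.0, 2.1, 1.9]
print([epoch2iter(e) for e in (1, 2)])  # [3, 6]
```

`itertools.chain.from_iterable(all_losses)` is an equivalent spelling that avoids unpacking the outer list into arguments.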


selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i,e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.text(
                -0.06, 0.5, 'Epoch {}'.format(e),
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center', 
                transform=ax.transAxes)
        
        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')
    
#plt.savefig('images/ch17-wdcgan-samples.pdf')
plt.show()


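Mode collapse

Because of their adversarial nature, GAN models are notoriously hard to train. A common cause of failure is that the generator gets stuck in a small subspace of the data space and learns to generate only very similar samples. This failure is called mode collapse.

A geometric illustration is shown below: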

Image(filename='images/17_16.png', width=600)
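A collapsed generator produces batches whose samples sit close together. A rough way to see this is to compare the average pairwise distance within a batch. This heuristic is added here for illustration. It is not a technique from the article. The toy 2-D points below stand in for flattened images, and `math.dist` requires Python 3.8+.

```python
import itertools
import math


def mean_pairwise_distance(batch):
    """Average Euclidean distance over all pairs of (flattened) samples in a batch."""
    pairs = list(itertools.combinations(batch, 2))
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


# Toy 2-D "samples": a spread-out batch and a batch squeezed into one small region
diverse = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
collapsed = [[0.50, 0.50], [0.50, 0.50], [0.51, 0.50], [0.50, 0.49]]

print(round(mean_pairwise_distance(diverse), 3))    # 1.138
print(round(mean_pairwise_distance(collapsed), 3))  # 0.009
```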


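Besides the vanishing and exploding gradient problems discussed earlier, several other factors also make GAN training difficult.

Here are a few training tricks proposed in the literature: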

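Mini-batch discrimination

Batches consisting only of real examples or only of fake examples are fed to the discriminator separately. The discriminator may then compare the examples within a batch to decide whether the batch is real or fake. If the generator suffers from mode collapse, a batch of real examples is most likely more diverse than a batch of fake ones, so this comparison gives the discriminator a signal to detect the collapse.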
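The following is a minimal forward-pass sketch of a mini-batch discrimination layer. It follows the formulation in Salimans et al. (2016), "Improved Techniques for Training GANs." Each sample's intermediate discriminator features f(x_i) are projected through a tensor T into a B×C matrix M_i. For each row b, the similarities c_b(x_i, x_j) = exp(-||M_i,b - M_j,b||_1) are then summed over the batch. The sizes, the random `T`, and the plain-list implementation are assumptions for illustration. In a real model, T is learned and o(x_i) is concatenated with f(x_i) before the next layer.

```python
import math
import random


def minibatch_discrimination(features, T):
    """Minibatch-discrimination outputs in the style of Salimans et al. (2016).

    features: n feature vectors f(x_i), each of length A (from an inner D layer)
    T:        A x B x C tensor (nested lists); learnable in a real model
    Returns n vectors o(x_i) of length B, to be concatenated with f(x_i).
    """
    A, B, C = len(T), len(T[0]), len(T[0][0])
    # M_i = f(x_i) . T  gives one B x C matrix per sample
    M = [[[sum(f[a] * T[a][b][c] for a in range(A)) for c in range(C)]
          for b in range(B)]
         for f in features]
    n = len(features)
    outputs = []
    for i in range(n):
        o_i = []
        for b in range(B):
            # c_b(x_i, x_j) = exp(-||M_i,b - M_j,b||_1), summed over the batch;
            # the j == i term always contributes exp(0) = 1
            o_i.append(sum(
                math.exp(-sum(abs(M[i][b][c] - M[j][b][c]) for c in range(C)))
                for j in range(n)))
        outputs.append(o_i)
    return outputs


rng = random.Random(0)
A, B, C, n = 4, 3, 2, 5
T = [[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(B)] for _ in range(A)]
diverse = [[rng.gauss(0.0, 1.0) for _ in range(A)] for _ in range(n)]
collapsed = [[0.3, -0.2, 0.1, 0.5] for _ in range(n)]

print(minibatch_discrimination(collapsed, T)[0])  # [5.0, 5.0, 5.0]
print(minibatch_discrimination(diverse, T)[0])    # every entry lies between 1 and 5
```

A fully collapsed batch drives every o(x_i)_b to its maximum value n. This is exactly the cue the discriminator can exploit to flag a fake batch.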

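Feature matching

In feature matching, the generator's objective function is slightly modified by adding an extra term. This term minimizes the difference between original and generated images, measured on intermediate representations (feature maps) of the discriminator.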
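The extra term can be sketched as the squared L2 distance between the batch means of the discriminator's intermediate features for real and generated images, ||E f(x) - E f(G(z))||². This is the form used by Salimans et al. (2016), where it serves as the generator's entire objective. The weighted sum at the end follows the "extra term" description above. The values `adv_loss` and `fm_weight` are hypothetical placeholders, and the features are plain lists rather than real feature maps.

```python
def feature_matching_loss(real_features, fake_features):
    """Squared L2 distance between the mean discriminator features of a real
    batch and a fake batch: || mean f(x) - mean f(G(z)) ||_2^2."""
    dim = len(real_features[0])
    mean_real = [sum(f[k] for f in real_features) / len(real_features) for k in range(dim)]
    mean_fake = [sum(f[k] for f in fake_features) / len(fake_features) for k in range(dim)]
    return sum((r - g) ** 2 for r, g in zip(mean_real, mean_fake))


real_feats = [[1.0, 2.0], [3.0, 4.0]]  # batch mean: [2.0, 3.0]
fake_feats = [[0.0, 2.0], [2.0, 2.0]]  # batch mean: [1.0, 2.0]
fm = feature_matching_loss(real_feats, fake_feats)
print(fm)  # 2.0

# Hypothetical combination: adversarial generator loss plus the weighted extra term
adv_loss = 0.7   # placeholder value for the usual generator loss
fm_weight = 1.0  # hypothetical weighting hyperparameter
g_loss = adv_loss + fm_weight * fm
```

Matching only the feature means is a weaker requirement than fooling the current discriminator. This tends to give the generator a more stable target.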




Added: 2021-07-28 07:45:54  Updated: 2021-07-28 07:47:21