IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 教你编写第一个生成式对抗网络GAN -> 正文阅读

[人工智能]教你编写第一个生成式对抗网络GAN

朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow

前面讲解了那么多GAN的基础知识,我们已经比较深入地了解GAN了,但如果不动手将上面的理论知识融入到实战中,你依旧无法内化上面的内容,所以接着就通过TensorFlow来实现一个朴素GAN。(文章中使用的是Tensorflow 1.x版本的语法)

我们主要是创建一个最简单的GAN,然后训练它,使它可以生成与真实图片一样的手写数字图片。下面直接进行代码的编写。

(1)导入第三方库。

 import tensorflow as tf

  import numpy as np

  import pickle

  import matplotlib.pyplot as plt

我们使用TensorFlow来实现GAN的网络架构,并对构建的GAN进行训练;使用numpy来生成随机噪声,用于给生成器生成输入数据;使用pickle来持久化地保存变量;最后使用matplotlib来可视化GAN训练时两个网络结构损失的变化以及GAN生成的图片。

(2)因为是要训练GAN生成MNIST手写数据集中的图片,需要读入MNIST数据集中的真实图片作为训练判别器D的真实数据,TensorFlow提供了处理MNIST的方法,可以使用它读入MNIST数据。

 from tensorflow.examples.tutorials.mnist import input_data

  # 读入MNIST数据

  mnist = input_data.read_data_sets('./data/MNIST_data')

  img = mnist.train.images[500]   

#以灰度图的形式读入

  plt.imshow(img.reshape((28, 28)), cmap='Greys_r')

  plt.show()

读入MNIST图片后,每一张图片都由一个一维矩阵表示。

  print(type(img))

  print(img.shape)
输出如下。
  <class 'numpy.ndarray'>

  (784,)

PS:TensorFlow在1.9版本后,input_data.read_data_sets方法不会自动下载,如果本地没有MNIST数据集,就会报错,所以我们必须事先将它下载好。

接着定义用于接收输入的方法,使用TensorFlow的placeholder占位符来获得输入的数据。

  def get_inputs(real_size, noise_size):

      real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')

      noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')

      return real_img, noise_img

然后就可以实现生成器和判别器了,先来看生成器,代码如下。

  def generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):

  '''

  生成器

   :paramnoise_img: 生成器生成的噪声图片

:paramn_units: 隐藏层单元数

   :paramout_dim: 生成器输出的tensor的size,应该是32×32=784

   :param reuse: 是否重用空间

   :param alpha: leakeyReLU系数

   :return:

   '''

  with tf.variable_scope("generator", reuse=reuse):

            #全连接

            hidden1 = tf.layers.dense(noise_img, n_units)

            #返回最大值

            hidden1 = tf.maximum(alpha * hidden1, hidden1)

            hidden1 = tf.layers.dropout(hidden1, rate=0.2, training=True)

            #dense:全连接

            logits = tf.layers.dense(hidden1, out_dim)

            outputs = tf.tanh(logits)

            return logits, outputs

可以发现生成器的网络结构非常简单,只是一个具有单隐藏层的神经网络,其整体结构为输入层→隐藏层→输出层,一开始只是编写最简单的GAN,在后面的高级内容中,生成器和判别器的结构会复杂一些。

简单解释一下上面的代码,首先使用tf.variable_scope创建了一个名为generator的空间,主要目的是实现在该空间中,变量可以被重复使用且方便区分不同卷积层之间的组件。

接着使用tf.layers下的dense方法将输入层和隐藏层进行全连接。tf.layers模块提供了很多封装层次较高的方法,使用这些方法,我们可以更加轻松地构建相应的神经网络结构。这里使用dense方法,其作用就是实现全连接。

我们选择Leaky ReLU作为隐藏层的激活函数,使用tf.maximum方法返回通过Leaky ReLU激活后较大的值。

然后使用tf.layers的dropout方法,其做法就是按一定的概率随机弃用神经网络中的网络单元(即将该网络单元的参数置0),防止发生过拟合现象,dropout只能在训练时使用,在测试时不能使用。最后再通过dense方法,实现隐藏层与输出层全连接,并使用Tanh作为输出层的激活函数(试验中用Tanh作为激活函数生成器效果更好),Tanh函数的输出范围是?1~1,即表示生成图片的像素范围是?1~1,但MNIST数据集中真实图片的像素范围是0~1,所以在训练时,要调整真实图片的像素范围,让其与生成图片一致。

Leakey ReLU函数是ReLU函数的变种,与ReLU函数的不同之处在于,ReLU将所有的负值都设为零,而Leakey ReLU则给负值乘以一个斜率。

接着看判别器的代码。

  def discirminator(img, n_units, reuse=False, alpha=0.01):

      '''

  判别器

     	:paramimg: 图片(真实图片/生成图片)

  		:paramn_units:

  		:param reuse:

  		:param alpha:

  		:return:

     	'''

       with tf.variable_scope('discriminator', reuse=reuse):

            hidden1 = tf.layers.dense(img, n_units)

            hidden1 = tf.maximum(alpha * hidden1, hidden1)

            logits = tf.layers.dense(hidden1, 1)

            outputs = tf.sigmoid(logits)

            return logits, outputs

判别器的实现代码与生成器没有太大差别,稍有不同的地方就是,判别器的输出层只有一个网络单元且使用sigmoid作为输出层的激活函数,sigmoid函数输出值的范围是0~1。

生成器和判别器编写完成后,接着就来编写具体的计算图,首先做一些初始化工作,如定义需要的变量、清空default graph计算图。

img_size = mnist.train.images[0].shape[0]#真实图片大小

  noise_size = 100 #噪声,Generator的初始输入

  g_units = 128#生成器隐藏层参数

  d_units = 128

  alpha = 0.01 #leaky ReLU参数

  learning_rate = 0.001 #学习速率

  smooth = 0.1 #标签平滑

  # 重置default graph计算图以及nodes节点

  tf.reset_default_graph()

然后我们通过get_inputs方法获得真实图片的输入和噪声输入,并传入生成器和判别器进行训练,当然,现在只是构建GAN整个网络的训练结构。


#生成器

g_logits, g_outputs = generator(noise_img, g_units, img_size)

#判别器

d_logits_real, d_outputs_real = discirminator(real_img, d_units)

# 传入生成图片,为其打分

d_logits_fake, d_outputs_fake = discirminator(g_outputs, d_units, reuse=True)

上面的代码将噪声、生成器隐藏层节点数、真实图片大小传入生成器,传入真实图片的大小是因为要求生成器可以生成与真实图片大小一样的图片。

判别器一开始先传入真实图片和判别器隐藏层节点,为真实图片打分,接着再用相同的参数训练生成图片,为生成图片打分。

训练逻辑构建完成,接着就定义生成器和判别器的损失。先回忆一下前面对损失的描述,判别器的损失由判别器给真实图片打分与其期望分数的差距、判别器给生成图片打分与其期望分数的差距两部分构成。这里定义最高分为1、最低分为0,也就是说判别器希望给真实图片打1分,给生成图片打0分。生成器的损失实质上是生成图片与真实图片概率分布上的差距,这里将其转换为,生成器期望判别器给自己的生成图片打多少分与实际上判别器给生成图片打多少分的差距。

  d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(

      logits=d_logits_real, labels=tf.ones_like(d_logits_real))*(1-smooth))

  d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(

      logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)

  ))

  #判别器总损失

  d_loss = tf.add(d_loss_real, d_loss_fake)

  g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(

      logits=d_logits_fake, labels=tf.ones_like(d_logits_fake))*(1-smooth))

计算损失时使用tf.nn.sigmoid_cross_entropy_with_logits方法,它对传入的logits参数先使用sigmoid函数计算,然后再计算它们的cross entropy交叉熵损失,同时该方法优化了cross entropy的计算方式,使得结果不会溢出。从方法的名字就可以直观地看出它的作用。

损失定义好后,要做的就是最小化这个损失。

  # generator中的tensor

  g_vars = [var for var in train_vars if var.name.startswith("generator")]

  # discriminator中的tensor

  d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

  #AdamOptimizer优化损失

  d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)

  g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

要最小化损失,先要获得对应网络结构中的参数,也就是生成器和判别器的变量,这是最小化损失时要修改的对象。这里使用AdamOptimizer方法来最小化损失,其内部实现了Adam算法,该算法基于梯度下降算法,但它可以动态地调整每个参数的学习速率。

至此整个计算结果大致定义完成,接着开始实现具体的训练逻辑,先初始化一些与训练有关的变量。

  batch_size = 64 #每一轮训练数量

  epochs = 500 #训练迭代轮数

  n_sample = 25 #抽取样本数

  samples = [] #存储测试样例

  losses = [] #存储loss

  #保存生成器变量

  saver = tf.train.Saver(var_list=g_vars)

编写训练具体代码。

  with tf.Session() as sess:

     	# 初始化模型的参数

 		 sess.run(tf.global_variables_initializer())

      for e in range(epochs):

            for batch_i in range(mnist.train.num_examples // batch_size):

                 batch = mnist.train.next_batch(batch_size)

                 #28 × 28 = 784

                 batch_images = batch[0].reshape((batch_size, 784))

                 # 对图像像素进行scale,这是因为Tanh输出的结果介于(-1,1)之间,real和fake图片共享discriminator的参数

                 batch_images = batch_images * 2 -1

                 #生成噪声图片

                 batch_noise = np.random.uniform(-1,1,size=(batch_size, noise_size))

                 #先训练判别器,再训练生成器

                 _= sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img:batch_noise})

                 _= sess.run(g_train_opt, feed_dict={noise_img:batch_noise})

            #每一轮训练完后,都计算一下loss

            train_loss_d = sess.run(d_loss, feed_dict={real_img:batch_images, noise_img:batch_noise})

            # 判别器训练时真实图片的损失

            train_loss_d_real = sess.run(d_loss_real, feed_dict={real_img:batch_images,noise_img:batch_noise})

            # 判别器训练时生成图片的损失

            train_loss_d_fake = sess.run(d_loss_fake, feed_dict={real_img:batch_images,noise_img:batch_noise})

            # 生成器损失

            train_loss_g = sess.run(g_loss, feed_dict= {noise_img: batch_noise})

            print("训练轮数 {}/{}...".format(e + 1, epochs),

            "判别器总损失: {:.4f}(真实图片损失: {:.4f} + 虚假图片损失: {:.4f})...".format(train_loss_d,train_loss_d_real,train_loss_d_fake),"生成器损失: {:.4f}".format(train_loss_g))

            # 记录各类loss值

            losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

            # 抽取样本后期进行观察

            sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))

            #生成样本,保存起来后期观察

            gen_samples = sess.run(generator(noise_img, g_units, img_size, reuse=True),

            feed_dict={noise_img:sample_noise})

            samples.append(gen_samples)

            # 存储checkpoints

            saver.save(sess, './data/generator.ckpt')

            with open('./data/train_samples.pkl', 'wb') as f:

            pickle.dump(samples,f)

一开始是创建Session对象,然后使用双层for循环进行GAN的训练,第一层表示要训练多少轮,第二层表示每一轮训练时,要取的样本量,因为一口气训练完所有的真实图片效率会比较低,一般的做法是将其分割成多组,然后进行多轮训练,这里64张为一组。

接着就是读入一组真实数据,因为生成器使用Tanh作为输出层的激活函数,导致生成图片的像素范围是?1~1,所以这里也简单调整一下真实图片的像素访问,将其从0~1变为?1~1,然后使用numpy的uniform方法生成?1~1之间的随机噪声。准备好真实数据和噪声数据后,就可以丢入生成器和判别器了,数据会按我们之前设计好的计算图运行,值得注意的是,要先训练判别器,再训练生成器。

当本轮将所有的真实图片都训练了一遍后,计算一下本轮生成器和判别器的损失,并将损失记录起来,方便后面可视化GAN训练过程中损失的变化。为了直观地感受GAN训练时生成器的变化,每一轮GAN训练完都用此时的生成器生成一组图片并保存起来。训练逻辑编写完后,就可以让训练代码运行起来,输出如下内容。

训练轮数 1/500… 判别器总损失: 0.0190(真实图片损失: 0.0017 + 虚假图片损失: 0.0173)…

生成器损失: 4.1502

训练轮数 2/500… 判别器总损失: 1.0480(真实图片损失: 0.3772 + 虚假图片损失: 0.6708)…

生成器损失: 3.1548

训练轮数 3/500… 判别器总损失: 0.5315(真实图片损失: 0.3580 + 虚假图片损失: 0.1736)…

生成器损失: 2.8828

训练轮数 4/500… 判别器总损失: 2.9703(真实图片损失: 1.5434 + 虚假图片损失: 1.4268)…

生成器损失: 0.7844

训练轮数 5/500… 判别器总损失: 1.0076(真实图片损失: 0.5763 + 虚假图片损失: 0.4314)…

生成器损失: 1.8176

训练轮数 6/500… 判别器总损失: 0.7265(真实图片损失: 0.4558 + 虚假图片损失: 0.2707)…

生成器损失: 2.9691

训练轮数 7/500… 判别器总损失: 1.5635(真实图片损失: 0.8336 + 虚假图片损失: 0.7299)…

生成器损失: 2.1342

整个训练过程会花费30~40分钟。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-09 16:16:59  更:2021-10-09 16:17:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/16 4:29:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码