| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 教你编写第一个生成式对抗网络GAN -> 正文阅读 |
|
[人工智能]教你编写第一个生成式对抗网络GAN |
朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow 前面讲解了那么多GAN的基础知识,我们已经比较深入地了解GAN了,但如果不动手将上面的理论知识融入到实战中,你依旧无法内化上面的内容,所以接着就通过TensorFlow来实现一个朴素GAN。(文章中使用的是Tensorflow 1.x版本的语法) 我们主要是创建一个最简单的GAN,然后训练它,使它可以生成与真实图片一样的手写数字图片。下面直接进行代码的编写。 (1)导入第三方库。
我们使用TensorFlow来实现GAN的网络架构,并对构建的GAN进行训练;使用numpy来生成随机噪声,用于给生成器生成输入数据;使用pickle来持久化地保存变量;最后使用matplotlib来可视化GAN训练时两个网络结构损失的变化以及GAN生成的图片。 (2)因为是要训练GAN生成MNIST手写数据集中的图片,需要读入MNIST数据集中的真实图片作为训练判别器D的真实数据,TensorFlow提供了处理MNIST的方法,可以使用它读入MNIST数据。
读入MNIST图片后,每一张图片都由一个一维矩阵表示。
PS:TensorFlow在1.9版本后,input_data.read_data_sets方法不会自动下载,如果本地没有MNIST数据集,就会报错,所以我们必须事先将它下载好。 接着定义用于接收输入的方法,使用TensorFlow的placeholder占位符来获得输入的数据。
然后就可以实现生成器和判别器了,先来看生成器,代码如下。
可以发现生成器的网络结构非常简单,只是一个具有单隐藏层的神经网络,其整体结构为输入层→隐藏层→输出层,一开始只是编写最简单的GAN,在后面的高级内容中,生成器和判别器的结构会复杂一些。 简单解释一下上面的代码,首先使用tf.variable_scope创建了一个名为generator的空间,主要目的是实现在该空间中,变量可以被重复使用且方便区分不同卷积层之间的组件。 接着使用tf.layers下的dense方法将输入层和隐藏层进行全连接。tf.layers模块提供了很多封装层次较高的方法,使用这些方法,我们可以更加轻松地构建相应的神经网络结构。这里使用dense方法,其作用就是实现全连接。 我们选择Leaky ReLU作为隐藏层的激活函数,使用tf.maximum方法返回通过Leaky ReLU激活后较大的值。 然后使用tf.layers的dropout方法,其做法就是按一定的概率随机弃用神经网络中的网络单元(即将该网络单元的参数置0),防止发生过拟合现象,dropout只能在训练时使用,在测试时不能使用。最后再通过dense方法,实现隐藏层与输出层全连接,并使用Tanh作为输出层的激活函数(试验中用Tanh作为激活函数生成器效果更好),Tanh函数的输出范围是?1~1,即表示生成图片的像素范围是?1~1,但MNIST数据集中真实图片的像素范围是0~1,所以在训练时,要调整真实图片的像素范围,让其与生成图片一致。 Leakey ReLU函数是ReLU函数的变种,与ReLU函数的不同之处在于,ReLU将所有的负值都设为零,而Leakey ReLU则给负值乘以一个斜率。 接着看判别器的代码。
判别器的实现代码与生成器没有太大差别,稍有不同的地方就是,判别器的输出层只有一个网络单元且使用sigmoid作为输出层的激活函数,sigmoid函数输出值的范围是0~1。 生成器和判别器编写完成后,接着就来编写具体的计算图,首先做一些初始化工作,如定义需要的变量、清空default graph计算图。
然后我们通过get_inputs方法获得真实图片的输入和噪声输入,并传入生成器和判别器进行训练,当然,现在只是构建GAN整个网络的训练结构。
上面的代码将噪声、生成器隐藏层节点数、真实图片大小传入生成器,传入真实图片的大小是因为要求生成器可以生成与真实图片大小一样的图片。 判别器一开始先传入真实图片和判别器隐藏层节点,为真实图片打分,接着再用相同的参数训练生成图片,为生成图片打分。 训练逻辑构建完成,接着就定义生成器和判别器的损失。先回忆一下前面对损失的描述,判别器的损失由判别器给真实图片打分与其期望分数的差距、判别器给生成图片打分与其期望分数的差距两部分构成。这里定义最高分为1、最低分为0,也就是说判别器希望给真实图片打1分,给生成图片打0分。生成器的损失实质上是生成图片与真实图片概率分布上的差距,这里将其转换为,生成器期望判别器给自己的生成图片打多少分与实际上判别器给生成图片打多少分的差距。
计算损失时使用tf.nn.sigmoid_cross_entropy_with_logits方法,它对传入的logits参数先使用sigmoid函数计算,然后再计算它们的cross entropy交叉熵损失,同时该方法优化了cross entropy的计算方式,使得结果不会溢出。从方法的名字就可以直观地看出它的作用。 损失定义好后,要做的就是最小化这个损失。
要最小化损失,先要获得对应网络结构中的参数,也就是生成器和判别器的变量,这是最小化损失时要修改的对象。这里使用AdamOptimizer方法来最小化损失,其内部实现了Adam算法,该算法基于梯度下降算法,但它可以动态地调整每个参数的学习速率。 至此整个计算结果大致定义完成,接着开始实现具体的训练逻辑,先初始化一些与训练有关的变量。
编写训练具体代码。
一开始是创建Session对象,然后使用双层for循环进行GAN的训练,第一层表示要训练多少轮,第二层表示每一轮训练时,要取的样本量,因为一口气训练完所有的真实图片效率会比较低,一般的做法是将其分割成多组,然后进行多轮训练,这里64张为一组。 接着就是读入一组真实数据,因为生成器使用Tanh作为输出层的激活函数,导致生成图片的像素范围是?1~1,所以这里也简单调整一下真实图片的像素访问,将其从0~1变为?1~1,然后使用numpy的uniform方法生成?1~1之间的随机噪声。准备好真实数据和噪声数据后,就可以丢入生成器和判别器了,数据会按我们之前设计好的计算图运行,值得注意的是,要先训练判别器,再训练生成器。 当本轮将所有的真实图片都训练了一遍后,计算一下本轮生成器和判别器的损失,并将损失记录起来,方便后面可视化GAN训练过程中损失的变化。为了直观地感受GAN训练时生成器的变化,每一轮GAN训练完都用此时的生成器生成一组图片并保存起来。训练逻辑编写完后,就可以让训练代码运行起来,输出如下内容。 训练轮数 1/500… 判别器总损失: 0.0190(真实图片损失: 0.0017 + 虚假图片损失: 0.0173)… 生成器损失: 4.1502 训练轮数 2/500… 判别器总损失: 1.0480(真实图片损失: 0.3772 + 虚假图片损失: 0.6708)… 生成器损失: 3.1548 训练轮数 3/500… 判别器总损失: 0.5315(真实图片损失: 0.3580 + 虚假图片损失: 0.1736)… 生成器损失: 2.8828 训练轮数 4/500… 判别器总损失: 2.9703(真实图片损失: 1.5434 + 虚假图片损失: 1.4268)… 生成器损失: 0.7844 训练轮数 5/500… 判别器总损失: 1.0076(真实图片损失: 0.5763 + 虚假图片损失: 0.4314)… 生成器损失: 1.8176 训练轮数 6/500… 判别器总损失: 0.7265(真实图片损失: 0.4558 + 虚假图片损失: 0.2707)… 生成器损失: 2.9691 训练轮数 7/500… 判别器总损失: 1.5635(真实图片损失: 0.8336 + 虚假图片损失: 0.7299)… 生成器损失: 2.1342 整个训练过程会花费30~40分钟。 |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/11 14:22:55- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |