IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于生成对抗网络(GAN)生成手写数字(MNIST) -> 正文阅读

[人工智能]基于生成对抗网络(GAN)生成手写数字(MNIST)

GAN生成手写数字识别

生成对抗网络介绍:

生成对抗网络就是两个网络进行相互对抗,相互进行的过程。GAN的主要部分为生成器Generator和判别器Discriminator。
生成器:输入一个随机向量,生成一个图片。希望生成的图片越像真的越好。
判别器:输出一个图片,判别图片的真伪。希望生成的图片判别为假,数据集中的图片判别为真。

生成器:

     def Generator(self):
        #构建模型
        gen=models.Sequential()
        gen.add(layers.Dense(256,input_dim=self.dims))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(512))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(1024))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(self.row*self.col*self.channel,activation='sigmoid'))
        gen.add(Reshape(self.shape))

        #输出模型
        gen.summary()

        #生成随机噪声
        noise=Input(shape=(self.dims,))
        #生成图片
        img=gen(noise)

        return Model(noise,img)

判别器:

    def Discrimation(self):
        #构建模型
        model=models.Sequential()
        model.add(layers.Flatten(input_shape=self.shape))
        model.add(layers.Dense(512))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(256))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(1,activation='sigmoid'))
        #输出模型
        model.summary()
        #输入
        input=Input(shape=(self.shape))
        #输出
        output=model(input)

        return Model(input,output)

训练代码:

#coding=utf-8
from keras import layers,datasets,models,optimizers
from keras.layers import Reshape,Input
from keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.datasets import mnist

class GAN():
    def __init__(self):

        self.row=28
        self.col=28
        self.channel=1
        self.shape=(self.row,self.col,self.channel)
        self.dims=100

        optimizer=optimizers.Adam(0.0002, 0.5)

        # 判别器
        self.Dis = self.Discrimation()
        self.Dis.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])

        # 生成器
        self.Gen = self.Generator()
        # 随机噪声
        z = Input(shape=(self.dims,))
        img = self.Gen(z)
        # 判别器不训练
        self.Dis.trainable = False
        # 对假图片进行预测
        validity = self.Dis(img)
        #返回模型
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def Generator(self):
        #构建模型
        gen=models.Sequential()
        gen.add(layers.Dense(256,input_dim=self.dims))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(512))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(1024))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(self.row*self.col*self.channel,activation='sigmoid'))
        gen.add(Reshape(self.shape))

        #输出模型
        gen.summary()

        #生成随机噪声
        noise=Input(shape=(self.dims,))
        #生成图片
        img=gen(noise)

        return Model(noise,img)

    def Discrimation(self):
        #构建模型
        model=models.Sequential()
        model.add(layers.Flatten(input_shape=self.shape))
        model.add(layers.Dense(512))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(256))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(1,activation='sigmoid'))
        #输出模型
        model.summary()
        #输入
        input=Input(shape=(self.shape))
        #输出
        output=model(input)

        return Model(input,output)

    def train(self,epochs,batch_size,save_size):
        #加载数据集
        (train_x,_),(_,_)=mnist.load_data()
        #数据集归一化
        train_x=train_x/255.
        train_x=np.expand_dims(train_x,axis=3)

        #生成标签
        real=np.ones((batch_size,1))
        fake=np.zeros((batch_size,1))

        for epoch in range(epochs):

            # 随机选取一批图片
            idx = np.random.randint(0, train_x.shape[0], batch_size)
            imgs = train_x[idx]

            # 生成一批噪声
            noise = np.random.normal(0, 1, (batch_size, self.dims))

            gen_image=self.Gen(noise)

            d_loss_real=self.Dis.train_on_batch(imgs,real)
            d_loss_fake=self.Dis.train_on_batch(gen_image,fake)
            d_loss=0.5*np.add(d_loss_real,d_loss_fake)

            noise=np.random.normal(0,1,(batch_size,self.dims))

            g_loss=self.combined.train_on_batch(noise,real)

            print('Epoch:{} ,D_loss:{}, D_acc:{} ,G_loss:{} '.format(epoch,d_loss[0],100*d_loss[1],g_loss))

            if epoch%save_size==0:
                self.save_image(epoch)

    def save_image(self,epoch):
        r,c=5,5

        noise=np.random.normal(0,1,(r*c,self.dims))
        gen_img=self.Gen.predict(noise)
        cnt=0

        fig,axs=plt.subplots(r,c)
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_img[cnt,:,:,0],cmap='gray')
                axs[i, j].axis('off')
                cnt+=1

        fig.savefig("image/{}.png".format(epoch))
        plt.close()

    def save(self):
        self.Dis.save('model/Discrimator.h5')
        self.Gen.save('model/Generator.h5')

if __name__ == '__main__':
    gan=GAN()
    gan.train(20000,512,100)
    gan.save()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-01 11:55:38  更:2021-09-01 11:57:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 16:32:17-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码