基于GAN的手写数字生成
生成对抗网络介绍:
生成对抗网络由两个相互对抗、共同进步的网络组成。GAN的主要部分为生成器Generator和判别器Discriminator。 生成器:输入一个随机向量,生成一张图片,希望生成的图片越像真的越好。 判别器:输入一张图片,判别图片的真伪,希望生成的图片被判别为假,数据集中的图片被判别为真。
生成器:
def Generator(self):
gen=models.Sequential()
gen.add(layers.Dense(256,input_dim=self.dims))
gen.add(layers.LeakyReLU(alpha=0.2))
gen.add(layers.BatchNormalization(momentum=0.8))
gen.add(layers.Dense(512))
gen.add(layers.LeakyReLU(alpha=0.2))
gen.add(layers.BatchNormalization(momentum=0.8))
gen.add(layers.Dense(1024))
gen.add(layers.LeakyReLU(alpha=0.2))
gen.add(layers.BatchNormalization(momentum=0.8))
gen.add(layers.Dense(self.row*self.col*self.channel,activation='sigmoid'))
gen.add(Reshape(self.shape))
gen.summary()
noise=Input(shape=(self.dims,))
img=gen(noise)
return Model(noise,img)
判别器:
def Discrimation(self):
model=models.Sequential()
model.add(layers.Flatten(input_shape=self.shape))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(1,activation='sigmoid'))
model.summary()
input=Input(shape=(self.shape))
output=model(input)
return Model(input,output)
训练代码:
import os

import matplotlib.pyplot as plt
import numpy as np

from keras import layers, datasets, models, optimizers
from keras.datasets import mnist
from keras.layers import Reshape, Input
from keras.models import Model
from keras.optimizers import Adam
class GAN():
    """A minimal fully-connected GAN that learns to generate MNIST digits.

    The generator maps a 100-d Gaussian noise vector to a 28x28x1 image in
    [0, 1]; the discriminator scores an image as real (1) or fake (0).  A
    combined model (noise -> generator -> frozen discriminator) trains the
    generator to fool the discriminator.
    """

    def __init__(self):
        # MNIST geometry: 28x28 grayscale.
        self.row = 28
        self.col = 28
        self.channel = 1
        self.shape = (self.row, self.col, self.channel)
        # Size of the latent noise vector fed to the generator.
        self.dims = 100
        # Adam(lr=2e-4, beta_1=0.5): the common GAN optimizer setting.
        optimizer = optimizers.Adam(0.0002, 0.5)

        # The discriminator trains standalone on real/fake batches.
        self.Dis = self.Discrimation()
        self.Dis.compile(loss='binary_crossentropy',
                         optimizer=optimizer,
                         metrics=['accuracy'])

        self.Gen = self.Generator()
        # Combined model: freeze the discriminator AFTER compiling it above,
        # so self.combined updates only the generator's weights while
        # self.Dis keeps training normally on its own.
        z = Input(shape=(self.dims,))
        img = self.Gen(z)
        self.Dis.trainable = False
        validity = self.Dis(img)
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def Generator(self):
        """Build the generator: noise (self.dims,) -> image of self.shape."""
        gen = models.Sequential()
        # Widening Dense blocks (256 -> 512 -> 1024), each followed by
        # LeakyReLU and BatchNormalization.
        gen.add(layers.Dense(256, input_dim=self.dims))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))
        gen.add(layers.Dense(512))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))
        gen.add(layers.Dense(1024))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))
        # Sigmoid keeps pixels in [0, 1], matching the /255 scaling in train().
        gen.add(layers.Dense(self.row * self.col * self.channel,
                             activation='sigmoid'))
        gen.add(Reshape(self.shape))
        gen.summary()
        # Wrap the Sequential stack into a functional Model: noise -> image.
        noise = Input(shape=(self.dims,))
        img = gen(noise)
        return Model(noise, img)

    def Discrimation(self):
        """Build the discriminator: image of self.shape -> probability it is real."""
        model = models.Sequential()
        model.add(layers.Flatten(input_shape=self.shape))
        model.add(layers.Dense(512))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(256))
        model.add(layers.LeakyReLU(alpha=0.2))
        # Single sigmoid unit: 1 = real, 0 = fake.
        model.add(layers.Dense(1, activation='sigmoid'))
        model.summary()
        # Renamed from `input` to avoid shadowing the builtin.
        img_in = Input(shape=self.shape)
        validity = model(img_in)
        return Model(img_in, validity)

    def train(self, epochs, batch_size, save_size):
        """Run `epochs` alternating discriminator/generator updates.

        Args:
            epochs: number of training steps (one D update + one G update each).
            batch_size: images per batch for both networks.
            save_size: save a 5x5 sample grid every `save_size` steps.
        """
        (train_x, _), (_, _) = mnist.load_data()
        # Scale to [0, 1] to match the generator's sigmoid output, and add
        # the channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        train_x = train_x / 255.
        train_x = np.expand_dims(train_x, axis=3)
        real = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # --- Train the discriminator on one real and one fake batch ---
            idx = np.random.randint(0, train_x.shape[0], batch_size)
            imgs = train_x[idx]
            noise = np.random.normal(0, 1, (batch_size, self.dims))
            # predict() returns a numpy array; calling the model directly
            # (as the original did) yields a tensor, which train_on_batch
            # may not accept alongside numpy labels.
            gen_image = self.Gen.predict(noise)
            d_loss_real = self.Dis.train_on_batch(imgs, real)
            d_loss_fake = self.Dis.train_on_batch(gen_image, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # --- Train the generator through the frozen-D combined model ---
            noise = np.random.normal(0, 1, (batch_size, self.dims))
            g_loss = self.combined.train_on_batch(noise, real)
            print('Epoch:{} ,D_loss:{}, D_acc:{} ,G_loss:{} '.format(
                epoch, d_loss[0], 100 * d_loss[1], g_loss))
            if epoch % save_size == 0:
                self.save_image(epoch)

    def save_image(self, epoch):
        """Save a 5x5 grid of generated digits to image/<epoch>.png."""
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.dims))
        gen_img = self.Gen.predict(noise)
        # Create the output directory on first use instead of crashing.
        os.makedirs('image', exist_ok=True)
        cnt = 0
        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_img[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("image/{}.png".format(epoch))
        # Close this figure specifically to free memory over a long run.
        plt.close(fig)

    def save(self):
        """Persist both trained networks as HDF5 files under model/."""
        os.makedirs('model', exist_ok=True)
        self.Dis.save('model/Discrimator.h5')
        self.Gen.save('model/Generator.h5')
if __name__ == '__main__':
    # Train for 20000 steps with batches of 512, sampling every 100 steps,
    # then save both networks.
    network = GAN()
    network.train(epochs=20000, batch_size=512, save_size=100)
    network.save()
|