6 COGAN (Coupled Generative Adversarial Networks: one model, two uses)
COGAN is a coupled generative adversarial network: the coupling inside the network lets a single input produce two different outputs.
Concretely, the implementation is:
1. Build two generator models and two discriminator models.
2. The two generators share part of their feature-extraction layers and split apart only in the final image-generation layers, so that they produce two different types of images.
3. The two discriminators likewise share part of their feature-extraction layers and split apart only in the final real/fake layers, so that each judges its own type of image.
The core idea is weight sharing: one network generates two differently styled sets of images, one network serving two purposes, which amounts to one network doing the work of two. COGAN is trained in the following steps:
1. Build two datasets with different styles.
2. Randomly generate batch_size N-dimensional vectors and use the two generators to produce images from them.
3. Use the two discriminators to judge the images produced by the two generators, and to judge random samples drawn from the two differently styled datasets.
4. Compare the discriminators' outputs on the generated images against 1 and use the result to train the two generators.
from __future__ import print_function, division
import scipy.ndimage  # import the submodule explicitly; `import scipy` alone does not expose it
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, GlobalAveragePooling2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import numpy as np
class COGAN():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.num_classes = 10
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
self.d1, self.d2 = self.build_discriminators()
self.d1.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
self.d2.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
self.g1, self.g2 = self.build_generators()
z = Input(shape=(self.latent_dim,))
img1 = self.g1(z)
img2 = self.g2(z)
self.d1.trainable = False
self.d2.trainable = False
valid1 = self.d1(img1)
valid2 = self.d2(img2)
self.combined = Model(z, [valid1, valid2])
self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
optimizer=optimizer)
def build_generators(self):
noise = Input(shape=(self.latent_dim,))
x = Dense(32 * 7 * 7, activation="relu", input_dim=self.latent_dim)(noise)
x = Reshape((7, 7, 32))(x)
x = Conv2D(64, kernel_size=3, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = UpSampling2D()(x)
x = Conv2D(128, kernel_size=3, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = UpSampling2D()(x)
x = Conv2D(128, kernel_size=3, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
feature_repr = Activation("relu")(x)
model = Model(noise, feature_repr)
noise = Input(shape=(self.latent_dim,))
feature_repr = model(noise)
g1 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
g1 = BatchNormalization(momentum=0.8)(g1)
g1 = Activation("relu")(g1)
g1 = Conv2D(64, kernel_size=3, padding="same")(g1)
g1 = BatchNormalization(momentum=0.8)(g1)
g1 = Activation("relu")(g1)
g1 = Conv2D(64, kernel_size=1, padding="same")(g1)
g1 = BatchNormalization(momentum=0.8)(g1)
g1 = Activation("relu")(g1)
g1 = Conv2D(self.channels, kernel_size=1, padding="same")(g1)
img1 = Activation("tanh")(g1)
g2 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
g2 = BatchNormalization(momentum=0.8)(g2)
g2 = Activation("relu")(g2)
g2 = Conv2D(64, kernel_size=3, padding="same")(g2)
g2 = BatchNormalization(momentum=0.8)(g2)
g2 = Activation("relu")(g2)
g2 = Conv2D(64, kernel_size=1, padding="same")(g2)
g2 = BatchNormalization(momentum=0.8)(g2)
g2 = Activation("relu")(g2)
g2 = Conv2D(self.channels, kernel_size=1, padding="same")(g2)
img2 = Activation("tanh")(g2)
return Model(noise, img1), Model(noise, img2)
def build_discriminators(self):
img = Input(shape=self.img_shape)
x = Conv2D(64, kernel_size=3, strides=2, padding="same")(img)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = Conv2D(64, kernel_size=3, strides=2, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = GlobalAveragePooling2D()(x)
feature_repr = Activation("relu")(x)
model = Model(img, feature_repr)
img1 = Input(shape=self.img_shape)
img2 = Input(shape=self.img_shape)
img1_embedding = model(img1)
img2_embedding = model(img2)
validity1 = Dense(1, activation='sigmoid')(img1_embedding)
validity2 = Dense(1, activation='sigmoid')(img2_embedding)
return Model(img1, validity1), Model(img2, validity2)
def train(self, epochs, batch_size=128, sample_interval=50):
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
X1 = X_train[:int(X_train.shape[0] / 2)]
X2 = X_train[int(X_train.shape[0] / 2):]
X2 = scipy.ndimage.rotate(X2, 90, axes=(1, 2))  # scipy.ndimage.interpolation is deprecated; call rotate directly
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
idx = np.random.randint(0, X1.shape[0], batch_size)
imgs1 = X1[idx]
imgs2 = X2[idx]
noise = np.random.normal(0, 1, (batch_size, 100))
gen_imgs1 = self.g1.predict(noise)
gen_imgs2 = self.g2.predict(noise)
d1_loss_real = self.d1.train_on_batch(imgs1, valid)
d2_loss_real = self.d2.train_on_batch(imgs2, valid)
d1_loss_fake = self.d1.train_on_batch(gen_imgs1, fake)
d2_loss_fake = self.d2.train_on_batch(gen_imgs2, fake)
d1_loss = 0.5 * np.add(d1_loss_real, d1_loss_fake)
d2_loss = 0.5 * np.add(d2_loss_real, d2_loss_fake)
g_loss = self.combined.train_on_batch(noise, [valid, valid])
print("%d [D1 loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]" \
% (epoch, d1_loss[0], 100 * d1_loss[1], d2_loss[0], 100 * d2_loss[1], g_loss[0]))
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self, epoch):
r, c = 4, 4
noise = np.random.normal(0, 1, (r * int(c / 2), 100))
gen_imgs1 = self.g1.predict(noise)
gen_imgs2 = self.g2.predict(noise)
gen_imgs = np.concatenate([gen_imgs1, gen_imgs2])
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig("images/mnist_%d.png" % epoch)
plt.close()
if __name__ == '__main__':
if not os.path.exists("./images"):
os.makedirs("./images")
gan = COGAN()
gan.train(epochs=30000, batch_size=256, sample_interval=200)
The training results are poor. My rough diagnosis is that the learning rate is too large: the generator's loss never comes down.
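A possible remedy following that diagnosis (untested; the values below are guesses, not tuned): lower the learning rates, and give each compile() call its own optimizer instance instead of sharing one Adam object across all three models, which can itself misbehave in TF2:

# Hypothetical tweak to COGAN.__init__: smaller learning rates, one Adam per model
self.d1.compile(loss='binary_crossentropy', optimizer=Adam(0.0001, 0.5), metrics=['accuracy'])
self.d2.compile(loss='binary_crossentropy', optimizer=Adam(0.0001, 0.5), metrics=['accuracy'])
self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'], optimizer=Adam(0.0002, 0.5))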
7 LSGAN (Least Squares GAN: mean squared error instead of cross-entropy)
LSGAN is the least squares GAN.
Its defining feature is that the loss is computed with mean squared error instead of cross-entropy.
Both the discriminator's training and the generator's training replace cross-entropy with mean squared error.
In other words, it is a vanilla GAN with the loss function swapped out.
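For reference, these are the least-squares objectives from the LSGAN paper (with label b = 1 for real, a = 0 for fake, and generator target c = 1), which the 'mse' loss in the code below implements:

$$\min_D V(D) = \tfrac{1}{2}\,\mathbb{E}_{x\sim p_{data}(x)}\!\left[(D(x)-1)^2\right] + \tfrac{1}{2}\,\mathbb{E}_{z\sim p_z(z)}\!\left[D(G(z))^2\right]$$

$$\min_G V(G) = \tfrac{1}{2}\,\mathbb{E}_{z\sim p_z(z)}\!\left[(D(G(z))-1)^2\right]$$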
7.1 Training procedure
LSGAN is trained in the following steps:
1. Randomly pick batch_size real images.
2. Randomly generate batch_size N-dimensional vectors and pass them to the Generator to produce batch_size fake images.
3. Label the real images 1 and the fake images 0, and train the Discriminator on both, using mean squared error as the loss.
4. Train the Generator on the loss obtained by comparing the Discriminator's predictions on the fake images against 1 (if the Discriminator scores a fake image as 1, the generated image looks "real"); this loss also uses mean squared error.
from __future__ import print_function, division
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
class LSGAN():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])
self.generator = self.build_generator()
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
self.discriminator.trainable = False
valid = self.discriminator(img)
self.combined = Model(z, valid)
self.combined.compile(loss='mse', optimizer=optimizer)
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1))
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
g_loss = self.combined.train_on_batch(noise, valid)
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()
if __name__ == '__main__':
if not os.path.exists("./images"):
os.makedirs("./images")
gan = LSGAN()
gan.train(epochs=30000, batch_size=512, sample_interval=200)
8 CycleGAN (style transfer)
CycleGAN is a GAN for image-to-image translation.
Image-to-image translation is a class of vision and graphics problems whose goal is to learn a mapping between an input image and an output image.
For many tasks, however, paired training data is unavailable.
CycleGAN offers a way to learn to translate images from a source domain X to a target domain Y without paired examples.
Its structure closely resembles the semantic-segmentation networks we have studied, so the generator first downsamples and then upsamples.
8.1 Training procedure
CycleGAN is trained in the following steps:
1. Build two generator models: one translates images from style A to style B, the other from style B to style A.
2. Build two discriminator models: one judges whether style-A images are real, the other judges style-B images.
3. The discriminators are trained with the same loss function as LSGAN, on whether their real/fake judgments are correct.
4. The generators are trained to satisfy the six criteria below (a minimal Keras sketch follows the note after this list):
a. A fake image translated from style A to style B must fool discriminator B.
b. A fake image translated from style B to style A must fool discriminator A.
c. A fake image translated from style A to style B can be translated back into the original image A by generator BA.
d. A fake image translated from style B to style A can be translated back into the original image B by generator AB.
e. A real style-A image passed through generator BA stays unchanged.
f. A real style-B image passed through generator AB stays unchanged.
Criteria c and d push the generator to find the minimal change it needs to make: to turn a zebra into a brown horse, changing only the horse's color is enough to fool the discriminator, so a style-A image passing through generator AB only needs the zebra itself converted.
Criteria e and f let the two generators tell the two styles apart: generator AB only transforms style-A images, and generator BA only transforms style-B images.
The code for this section has some problems; I also ran into issues when reproducing it in TensorFlow myself, so go straight to this blog post: 好像还挺好玩的GAN7——CycleGAN实现图像风格转换_Bubbliiiing的学习小课堂-CSDN博客 https://blog.csdn.net/weixin_44791964/article/details/103780922
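Since the code itself is omitted here, below is a minimal sketch (my own reconstruction, not the blog's exact implementation) of how criteria a-f become the combined generator model in Keras. build_generator and build_discriminator are placeholders for your own downsample-upsample generator and LSGAN-style discriminator; img_shape and the loss weights are common defaults, not tuned values:

# Minimal sketch of the CycleGAN combined model (assumed structure, not the blog's code)
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

img_shape = (128, 128, 3)                            # assumed image size
g_AB, g_BA = build_generator(), build_generator()    # A -> B and B -> A translators
d_A, d_B = build_discriminator(), build_discriminator()
d_A.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])  # LSGAN loss
d_B.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

img_A, img_B = Input(shape=img_shape), Input(shape=img_shape)
fake_B = g_AB(img_A)           # a: must fool d_B
fake_A = g_BA(img_B)           # b: must fool d_A
reconstr_A = g_BA(fake_B)      # c: A -> B -> A should return the original A
reconstr_B = g_AB(fake_A)      # d: B -> A -> B should return the original B
img_A_id = g_BA(img_A)         # e: a real A through g_BA stays unchanged
img_B_id = g_AB(img_B)         # f: a real B through g_AB stays unchanged
d_A.trainable, d_B.trainable = False, False
valid_A, valid_B = d_A(fake_A), d_B(fake_B)

combined = Model([img_A, img_B],
                 [valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id])
combined.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                 loss_weights=[1, 1, 10.0, 10.0, 1.0, 1.0],   # cycle weighted highest
                 optimizer=Adam(0.0002, 0.5))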
9 SRGAN (super-resolution GAN)
SRGAN comes from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.
Its main function is to take a low-resolution image as input and generate a high-resolution one.
The paper notes that ordinary super-resolution models are trained with mean squared error alone as the loss; this achieves high peak signal-to-noise ratios, but the recovered images usually lose their high-frequency detail.
SRGAN uses a perceptual loss and an adversarial loss to make the recovered images look more realistic.
The perceptual loss works on features extracted by a convolutional network: by comparing the features of the generated image with the features of the target image after both pass through the network, it pushes the generated image closer to the target in semantics and style.
The adversarial loss comes from the GAN and trains on whether the image can fool the discriminator network.
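In the paper's notation (as I recall it; check the paper for the exact weighting), the generator minimizes a perceptual loss that combines the VGG content loss with a small adversarial term:

$$l^{SR} = l^{SR}_{VGG} + 10^{-3}\, l^{SR}_{Gen}$$

$$l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left(\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y}\right)^2$$

where φ_{i,j} is a VGG19 feature map. The code below uses its own loss_weights ([5e-1, 1]) rather than the paper's 10^{-3}.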
9.1 Generator network
Reading the generator diagram from left to right, the SRGAN generator consists of three parts:
1. The low-resolution image first passes through a convolution + ReLU block.
2. It then goes through B residual blocks, each containing two convolution + batch normalization + ReLU stages plus a residual connection.
3. It finally enters the upsampling stage, where two successive upsamplings enlarge the width and height to 4× the original, raising the resolution.
The first two parts extract features; the third raises the resolution.
9.2 Discriminator network
The SRGAN discriminator is a repeated stack of convolution + LeakyReLU + batch normalization blocks.
9.3 Training procedure
1. Train the discriminator (two losses).
Feed the real high-resolution images and the fake high-resolution images into the discriminator.
Comparing the predictions on the real images against 1 gives one loss; comparing the predictions on the fake images against 0 gives the other. Train on both.
2. Train the generator (two losses).
Feed the low-resolution images into the generator to obtain high-resolution images, and compare the discriminator's output on them against 1 to get one loss.
Feed the real and fake high-resolution images into a VGG network to extract both images' features, and compare those features to get the other loss.
3. The dataset only needs high-resolution images; the low-resolution images are obtained simply by downscaling them.
9.4 The scales involved
(1) 512×512: the original high-resolution image. (2) 128×128: the low-resolution image. (3) 512×512: the high-resolution image generated from the low-resolution one. (4) 32×32: the VGG19 feature maps that get compared.
9.5 The code: three files
The first file reads the image data and generates the low-resolution images.
The defaults are 128×128 high resolution and 32×32 low resolution, but in this experiment the parameters passed in are 512×512 high resolution and 128×128 low resolution.
import scipy.misc  # requires scipy <= 1.2.x; imread/imresize were removed in scipy 1.3
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
class DataLoader():
def __init__(self, dataset_name, img_res=(128, 128)):
self.dataset_name = dataset_name
self.img_res = img_res
def load_data(self, batch_size=1, is_testing=False):
data_type = "train" if not is_testing else "test"
path = glob('./datasets/%s/train/*' % (self.dataset_name))
batch_images = np.random.choice(path, size=batch_size)
imgs_hr = []
imgs_lr = []
for img_path in batch_images:
img = self.imread(img_path)
h, w = self.img_res
low_h, low_w = int(h / 4), int(w / 4)
img_hr = scipy.misc.imresize(img, self.img_res)
img_lr = scipy.misc.imresize(img, (low_h, low_w))
if not is_testing and np.random.random() < 0.5:
img_hr = np.fliplr(img_hr)
img_lr = np.fliplr(img_lr)
imgs_hr.append(img_hr)
imgs_lr.append(img_lr)
imgs_hr = np.array(imgs_hr) / 127.5 - 1.
imgs_lr = np.array(imgs_lr) / 127.5 - 1.
return imgs_hr, imgs_lr
def imread(self, path):
return scipy.misc.imread(path, mode='RGB').astype(np.float32)  # np.float was removed from recent numpy
The second file trains the model.
If you hit a scipy version error, the traceback points straight at the cause; just change the version: conda install scipy==1.2.1 works. (A Pillow-based alternative is sketched below.)
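Alternatively, if you would rather not pin scipy, a possible workaround (my suggestion, not the original author's) is to swap the removed scipy.misc helpers for Pillow equivalents inside DataLoader:

# Pillow-based stand-ins for scipy.misc.imread / scipy.misc.imresize,
# which were removed from scipy >= 1.3; drop these into DataLoader
# in place of the scipy.misc calls.
import numpy as np
from PIL import Image

def imread(path):
    # equivalent of scipy.misc.imread(path, mode='RGB')
    return np.array(Image.open(path).convert('RGB')).astype(np.float32)

def imresize(img, size):
    # equivalent of scipy.misc.imresize(img, (h, w)); note PIL expects (w, h),
    # and the old scipy default interpolation was bilinear
    h, w = size
    return np.array(Image.fromarray(np.uint8(img)).resize((w, h), Image.BILINEAR))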
from __future__ import print_function, division
import scipy
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
import tensorflow.keras.backend as K
class SRGAN():
def __init__(self):
self.channels = 3
self.lr_height = 128
self.lr_width = 128
self.lr_shape = (self.lr_height, self.lr_width, self.channels)
self.hr_height = self.lr_height*4
self.hr_width = self.lr_width*4
self.hr_shape = (self.hr_height, self.hr_width, self.channels)
self.n_residual_blocks = 16
optimizer = Adam(0.0002, 0.5)
self.vgg = self.build_vgg()
self.vgg.trainable = False
self.dataset_name = 'DIV2K_train_HR'
self.data_loader = DataLoader(dataset_name=self.dataset_name,
img_res=(self.hr_height, self.hr_width))
patch = int(self.hr_height / 2**4)
self.disc_patch = (patch, patch, 1)
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
self.discriminator.summary()
self.generator = self.build_generator()
self.generator.summary()
img_lr = Input(shape=self.lr_shape)
fake_hr = self.generator(img_lr)
fake_features = self.vgg(fake_hr)
self.discriminator.trainable = False
validity = self.discriminator(fake_hr)
self.combined = Model(img_lr, [validity, fake_features])
self.combined.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[5e-1, 1],
optimizer=optimizer)
def build_vgg(self):
    # In TF2 Keras, assigning to vgg.outputs no longer has any effect, and the
    # default VGG19 input is fixed at 224x224. Build the perceptual-feature
    # extractor directly from the intermediate layer with the HR input shape.
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=self.hr_shape)
    return Model(vgg.input, vgg.layers[9].output)
def build_generator(self):
def residual_block(layer_input, filters):
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
d = BatchNormalization(momentum=0.8)(d)
d = Activation('relu')(d)
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
d = BatchNormalization(momentum=0.8)(d)
d = Add()([d, layer_input])
return d
def deconv2d(layer_input):
u = UpSampling2D(size=2)(layer_input)
u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
u = Activation('relu')(u)
return u
img_lr = Input(shape=self.lr_shape)
c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
c1 = Activation('relu')(c1)
r = residual_block(c1, 64)
for _ in range(self.n_residual_blocks - 1):
r = residual_block(r, 64)
c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
c2 = BatchNormalization(momentum=0.8)(c2)
c2 = Add()([c2, c1])
u1 = deconv2d(c2)
u2 = deconv2d(u1)
gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
return Model(img_lr, gen_hr)
def build_discriminator(self):
def d_block(layer_input, filters, strides=1, bn=True):
"""Discriminator layer"""
d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
d = LeakyReLU(alpha=0.2)(d)
if bn:
d = BatchNormalization(momentum=0.8)(d)
return d
d0 = Input(shape=self.hr_shape)
d1 = d_block(d0, 64, bn=False)
d2 = d_block(d1, 64, strides=2)
d3 = d_block(d2, 128)
d4 = d_block(d3, 128, strides=2)
d5 = d_block(d4, 256)
d6 = d_block(d5, 256, strides=2)
d7 = d_block(d6, 512)
d8 = d_block(d7, 512, strides=2)
d9 = Dense(64*16)(d8)
d10 = LeakyReLU(alpha=0.2)(d9)
validity = Dense(1, activation='sigmoid')(d10)
return Model(d0, validity)
def scheduler(self,models,epoch):
if epoch % 20000 == 0 and epoch != 0:
for model in models:
lr = K.get_value(model.optimizer.lr)
K.set_value(model.optimizer.lr, lr * 0.5)
print("lr changed to {}".format(lr * 0.5))
def train(self, epochs ,init_epoch=0, batch_size=1, sample_interval=50):
start_time = datetime.datetime.now()
if init_epoch!= 0:
# tf.keras requires by_name=True when skip_mismatch=True
self.generator.load_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, init_epoch), by_name=True, skip_mismatch=True)
self.discriminator.load_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, init_epoch), by_name=True, skip_mismatch=True)
for epoch in range(init_epoch,epochs):
self.scheduler([self.combined,self.discriminator],epoch)
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
fake_hr = self.generator.predict(imgs_lr)
valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)
d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
valid = np.ones((batch_size,) + self.disc_patch)
image_features = self.vgg.predict(imgs_hr)
g_loss = self.combined.train_on_batch(imgs_lr, [valid, image_features])
print(d_loss,g_loss)
elapsed_time = datetime.datetime.now() - start_time
print ("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, feature loss: %05f] time: %s " \
% ( epoch, epochs,
d_loss[0], 100*d_loss[1],
g_loss[1],
g_loss[2],
elapsed_time))
if epoch % sample_interval == 0:
self.sample_images(epoch)
if epoch % 500 == 0 and epoch != init_epoch:
os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)
self.generator.save_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, epoch))
self.discriminator.save_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, epoch))
def sample_images(self, epoch):
os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
r, c = 2, 2
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
fake_hr = self.generator.predict(imgs_lr)
imgs_lr = 0.5 * imgs_lr + 0.5
fake_hr = 0.5 * fake_hr + 0.5
imgs_hr = 0.5 * imgs_hr + 0.5
titles = ['Generated', 'Original']
fig, axs = plt.subplots(r, c)
cnt = 0
for row in range(r):
for col, image in enumerate([fake_hr, imgs_hr]):
axs[row, col].imshow(image[row])
axs[row, col].set_title(titles[col])
axs[row, col].axis('off')
cnt += 1
fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
plt.close()
for i in range(r):
fig = plt.figure()
plt.imshow(imgs_lr[i])
fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
plt.close()
if __name__ == '__main__':
gan = SRGAN()
gan.train(epochs=60000,init_epoch = 0, batch_size=1, sample_interval=50)
The third file runs prediction.
It loads the model and weights and performs prediction.
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from srgan import SRGAN
from PIL import Image
import numpy as np
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def build_generator():
def residual_block(layer_input, filters):
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
d = Activation('relu')(d)
d = BatchNormalization(momentum=0.8)(d)
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
d = BatchNormalization(momentum=0.8)(d)
d = Add()([d, layer_input])
return d
def deconv2d(layer_input):
u = UpSampling2D(size=2)(layer_input)
u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
u = Activation('relu')(u)
return u
img_lr = Input(shape=[None,None,3])
c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
c1 = Activation('relu')(c1)
r = residual_block(c1, 64)
for _ in range(15):
r = residual_block(r, 64)
c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
c2 = BatchNormalization(momentum=0.8)(c2)
c2 = Add()([c2, c1])
u1 = deconv2d(c2)
u2 = deconv2d(u1)
gen_hr = Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
return Model(img_lr, gen_hr)
model = build_generator()
model.load_weights(r"weights\DIV2K_train_HR\gen_epoch6000.h5")
before_image = Image.open(r"./images/before.png")
new_image = Image.new('RGB', before_image.size, (128,128,128))
new_image.paste(before_image)
new_image = np.array(new_image)/127.5 - 1
print("图像大小:",new_image.shape)
new_image = np.expand_dims(new_image,axis=0)
fake = (model.predict(new_image)*0.5 + 0.5)*255
fake = Image.fromarray(np.uint8(fake[0]))
fake.save("out.png")
fake.show()
10 Summary
The basic idea of a GAN:
(1) Build an ordinary generator model, e.g. VGG16, U-Net, etc.
(2) Build an ordinary binary-classification discriminator model.
(3) Train the discriminator for one batch, then train the generator for one batch.
(4) Training the discriminator is a direct forward pass: half the batch is real images and half is generated fakes, which gives the discriminator's loss and trains the discriminator.
(5) Training the generator needs the discriminator's verdict on the images generated this round, so build a model chaining generation into discrimination, get the generator's loss from it, and train the generator (a one-step sketch is below).
Of the eight models above, understanding one thoroughly is enough.
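As a one-step sketch of points (3)-(5), assuming generator, discriminator, and a combined = discriminator(generator(z)) model are already built and compiled, with real_imgs, batch_size, and latent_dim defined as in the earlier examples:

# One training step of the generic GAN recipe above (names assumed from context)
import numpy as np

valid = np.ones((batch_size, 1))    # labels for real images
fake = np.zeros((batch_size, 1))    # labels for generated images

# (4) discriminator step: forward pass on half real, half fake images
noise = np.random.normal(0, 1, (batch_size, latent_dim))
gen_imgs = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_imgs, valid)
d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# (5) generator step: the chained model routes this batch's fakes through the
# (frozen) discriminator, so labelling them "real" updates only the generator
g_loss = combined.train_on_batch(noise, valid)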
References
【1】Keras 搭建自己的GAN生成对抗网络平台 (Bubbliiiing 深度学习 教程) - bilibili: https://www.bilibili.com/video/BV13J41187Fo?from=search&seid=14309111542489072351&spm_id_from=333.337.0.0
【2】A discussion linking to many related studies: 生成对抗网络的生成样本能否提高预测模型准确率? - 知乎: https://www.zhihu.com/question/372133109
【3】好像还挺好玩的GAN - Bubbliiiing的学习小课堂 - CSDN博客: https://blog.csdn.net/weixin_44791964/category_9625179.html