概述
GAN 神经网络(2014)
- 《Generative Adversarial Network》
- https://arxiv.org/pdf/1406.2661.pdf
此后,GAN 网络的各种改进版本也在不断涌现。
原理
GAN训练可以分为5步
- 样本图片x输入分类器
- 随机种子z经由生成器生成模拟图片x*
- 模拟图片x*输入分类器
- 根据分类器对样本图片 x 和模拟图片 x* 的判别结果计算损失,反馈(反向传播)给分类器并更新其参数
- 根据分类器对模拟图片 x* 的判别结果计算损失,反馈给生成器并更新其参数,使生成的图片更难被分类器识别
MXNet代码
导入库
import time
import gzip
import numpy as np
import matplotlib.pyplot as plt![请添加图片描述](https://img-blog.csdnimg.cn/153a1da1e6d8438681cddca1fae72021.bmp)
import mxnet as mx
环境和批大小
# Mini-batch size for the DataLoaders, and the MXNet context (device) on which
# every parameter and tensor in this script is allocated.
batch_size, device = 10, mx.cpu()
读取训练数据
def load_dataset():
    """Load the MNIST train and test splits into in-memory lists.

    Each image is passed through Gluon's ``ToTensor`` transform and converted
    to a NumPy array; each label is wrapped in ``np.array``.

    Returns:
        (train_img, train_lbl, eval_img, eval_lbl): four Python lists, the
        first two from the training split and the last two from the test split.
    """
    transform = mx.gluon.data.vision.transforms.ToTensor()

    def _load_split(train):
        # Walk the split once, collecting images and labels together
        # (previously each split was constructed and iterated twice).
        images, labels = [], []
        for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=train):
            images.append(transform(img).asnumpy())
            labels.append(np.array(lbl))
        return images, labels

    train_img, train_lbl = _load_split(True)
    eval_img, eval_lbl = _load_split(False)
    return train_img, train_lbl, eval_img, eval_lbl
# Materialise both MNIST splits once, then wrap each in a mini-batch loader.
train_img, train_lbl, eval_img, eval_lbl = load_dataset()

# Training batches are reshuffled every epoch.
train_data = mx.gluon.data.DataLoader(
    mx.gluon.data.ArrayDataset(train_img, train_lbl),
    shuffle=True,
    batch_size=batch_size,
)

# Evaluation batches keep a fixed order (eval_data is not used further in the visible code).
eval_data = mx.gluon.data.DataLoader(
    mx.gluon.data.ArrayDataset(eval_img, eval_lbl),
    shuffle=False,
    batch_size=batch_size,
)
预览训练数据
# Show a few hand-picked training digits side by side in a single figure.
idxs = (25, 47, 74, 88, 92)
for i, idx in enumerate(idxs):
    # Grid width follows the number of indices instead of a hard-coded 5.
    plt.subplot(1, len(idxs), i + 1)
    plt.xticks([])
    plt.yticks([])
    # [0] selects the first (channel) plane of the stored image array.
    img = train_img[idx][0].astype(np.float32)
    plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()
分类器
class Discriminator():
    """GAN discriminator: scores 28x28 images as real (label 1) or fake (label 0).

    Attributes:
        loss_fn: Gluon SigmoidBinaryCrossEntropyLoss applied to the single
            output unit of ``net`` (no sigmoid layer is added to ``net``).
        metric: ``mx.metric.Loss`` accumulating the loss values passed to it.
        net: Flatten -> Dense(200) -> LeakyReLU(0.02) -> LayerNorm -> Dense(1).
        trainer: Adam optimiser over all parameters of ``net``.

    Construction initialises the weights (Xavier, gaussian) on ``device`` and
    prints a layer summary for a (50, 1, 28, 28) dummy batch.
    """

    def __init__(self):
        self.loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        self.metric = mx.metric.Loss()

        self.net = mx.gluon.nn.HybridSequential()
        # Hidden block followed by the single-unit output layer.
        self.net.add(
            mx.gluon.nn.Flatten(),
            mx.gluon.nn.Dense(units=200),
            mx.gluon.nn.LeakyReLU(alpha=0.02),
            mx.gluon.nn.LayerNorm(),
            mx.gluon.nn.Dense(units=1),
        )
        self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'), ctx=device)

        adam = mx.optimizer.Adam(
            learning_rate=0.001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            lazy_update=True,
        )
        self.trainer = mx.gluon.Trainer(params=self.net.collect_params(), optimizer=adam)

        dummy = mx.ndarray.zeros(shape=(50, 1, 28, 28), dtype=np.float32, ctx=device)
        self.net.summary(dummy)


discriminator = Discriminator()
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (50, 1, 28, 28) 0
Flatten-1 (50, 784) 0
Dense-2 (50, 200) 157000
LeakyReLU-3 (50, 200) 0
LayerNorm-4 (50, 200) 400
Dense-5 (50, 1) 201
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 157601
Trainable params: 157601
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 157601
--------------------------------------------------------------------------------
生成器
class Generator():
    """GAN generator: maps 100-dimensional seeds to (batch, 1, 28, 28) images.

    Attributes:
        metric: ``mx.metric.Loss`` accumulating the loss values passed to it.
        net: Dense(200) -> LeakyReLU(0.02) -> LayerNorm -> Dense(784)
            -> sigmoid -> reshape to (batch, 1, 28, 28), so pixels lie in (0, 1).
        trainer: Adam optimiser over all parameters of ``net``.

    No loss function is stored here; the training loop scores generated images
    with the discriminator's loss. Construction initialises the weights
    (Xavier, gaussian) on ``device`` and prints a layer summary for a
    (50, 100) dummy seed batch.
    """

    def __init__(self):
        self.metric = mx.metric.Loss()

        self.net = mx.gluon.nn.HybridSequential()
        # Hidden block, then the 784-unit output that is squashed and reshaped.
        self.net.add(
            mx.gluon.nn.Dense(units=200),
            mx.gluon.nn.LeakyReLU(alpha=0.02),
            mx.gluon.nn.LayerNorm(),
            mx.gluon.nn.Dense(units=784),
            mx.gluon.nn.Activation(activation='sigmoid'),
            # Reshape code 0 copies the batch axis; -1 infers the channel axis (784 / 28 / 28 = 1).
            mx.gluon.nn.HybridLambda(lambda F, x: F.reshape(x, shape=(0, -1, 28, 28))),
        )
        self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'), ctx=device)

        adam = mx.optimizer.Adam(
            learning_rate=0.001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            lazy_update=True,
        )
        self.trainer = mx.gluon.Trainer(params=self.net.collect_params(), optimizer=adam)

        dummy = mx.ndarray.zeros(shape=(50, 100), dtype=np.float32, ctx=device)
        self.net.summary(dummy)


generator = Generator()
生成种子
def make_seed(batach, num, device=None):
    """Draw a batch of random generator seeds from a standard normal distribution.

    Args:
        batach: number of seed vectors (rows). The misspelled name is kept for
            compatibility with keyword callers.
        num: length of each seed vector (columns).
        device: MXNet context for the result; ``mx.cpu()`` when omitted.

    Returns:
        An NDArray of shape (batach, num) with entries drawn from N(0, 1).
    """
    ctx = mx.cpu() if device is None else device
    return mx.ndarray.normal(loc=0, scale=1, shape=(batach, num), ctx=ctx)
训练
# Adversarial training. For every mini-batch the discriminator is first trained
# to label real images 1 and generated images 0; then the generator is trained
# so that the discriminator labels its images 1.
for epoch in range(120):
    # Reset BOTH running-loss metrics so each epoch reports its own averages.
    # (The discriminator metric used to be reset twice and the generator's
    # never, which made the printed g_loss a cumulative mean since epoch 0.)
    discriminator.metric.reset()
    generator.metric.reset()
    tic = time.time()

    for datas, _ in train_data:
        actually_batch_size = datas.shape[0]
        datas = mx.gluon.utils.split_and_load(datas, [device])
        seeds = [make_seed(data.shape[0], 100, device) for data in datas]

        # --- Discriminator step: real -> 1, generated -> 0. ---
        for data, seed in zip(datas, seeds):
            length = data.shape[0]
            lbl_real = mx.ndarray.ones(shape=(length, 1), ctx=device)
            lbl_fake = mx.ndarray.zeros(shape=(length, 1), ctx=device)
            with mx.autograd.record():
                a = discriminator.loss_fn(discriminator.net(data), lbl_real)
                # detach() keeps this loss from back-propagating into the generator.
                b = discriminator.loss_fn(discriminator.net(generator.net(seed).detach()), lbl_fake)
                d_loss = a + b
            d_loss.backward()
            # mx.metric.Loss ignores its labels argument, so pass None explicitly.
            discriminator.metric.update(None, preds=d_loss)
        discriminator.trainer.step(actually_batch_size)

        # --- Generator step: push the discriminator's verdict on fakes towards 1. ---
        for seed in seeds:
            length = seed.shape[0]
            lbl_real = mx.ndarray.ones(shape=(length, 1), ctx=device)
            with mx.autograd.record():
                img = generator.net(seed)
                g_loss = discriminator.loss_fn(discriminator.net(img), lbl_real)
            g_loss.backward()
            generator.metric.update(None, preds=g_loss)
        generator.trainer.step(actually_batch_size)

    print("Epoch {:>2d}: cost:{:.1f}s d_loss:{:.3f} g_loss:{:.3f}".format(epoch, time.time()-tic, discriminator.metric.get()[1], generator.metric.get()[1]))

    # Preview 10 generated digits: draw one seed per panel as a single batch,
    # instead of generating a whole batch per panel and keeping only its first
    # image (which also relied on `length` leaking out of the inner loop).
    samples = generator.net(make_seed(10, 100, device)).asnumpy()
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.xticks([]); plt.yticks([])
        plt.imshow(samples[i][0], interpolation='none', cmap='Blues')
    plt.show()
Epoch 0: cost:43.4s d_loss:0.372 g_loss:4.094
Epoch 1: cost:40.2s d_loss:0.602 g_loss:3.242
Epoch 2: cost:40.5s d_loss:0.661 g_loss:2.887
Epoch 3: cost:40.2s d_loss:0.701 g_loss:2.688
Epoch 4: cost:40.3s d_loss:0.707 g_loss:2.563
Epoch 5: cost:39.2s d_loss:0.708 g_loss:2.480
Epoch 6: cost:52.1s d_loss:0.703 g_loss:2.422
Epoch 7: cost:44.2s d_loss:0.691 g_loss:2.384
Epoch 8: cost:53.6s d_loss:0.675 g_loss:2.359
Epoch 9: cost:38.2s d_loss:0.659 g_loss:2.344
Epoch 10: cost:38.2s d_loss:0.645 g_loss:2.333
Epoch 11: cost:38.2s d_loss:0.640 g_loss:2.325
Epoch 12: cost:38.1s d_loss:0.625 g_loss:2.319
Epoch 13: cost:38.3s d_loss:0.619 g_loss:2.313
Epoch 14: cost:38.3s d_loss:0.610 g_loss:2.308
Epoch 15: cost:38.2s d_loss:0.606 g_loss:2.304
Epoch 16: cost:38.2s d_loss:0.598 g_loss:2.301
Epoch 17: cost:38.3s d_loss:0.590 g_loss:2.298
Epoch 18: cost:38.2s d_loss:0.589 g_loss:2.296
Epoch 19: cost:38.5s d_loss:0.585 g_loss:2.294
Epoch 20: cost:38.3s d_loss:0.585 g_loss:2.292
训练 20 轮(Epoch 0–20)的输出如图。

损失值评价
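当生成器足够好、分类器对任意输入都只能输出 0.5 时,单个样本的二元交叉熵为 $-\ln(0.5) = \ln 2 \approx 0.693$,这就是双方达到平衡时的理想损失值。注意:代码中分类器损失 `d_loss = a + b` 是真实图片与模拟图片两项交叉熵之和,所以按上面打印的 d_loss 来看,平衡时的理想值应为 $2\ln 2 \approx 1.386$;生成器损失 `g_loss` 只有一项,理想值为 $\ln 2 \approx 0.693$。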
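训练早期(前6轮),分类器损失值迅速从0.372上升到0.708,同时生成器损失值也在迅速下降,此时生成器略微领先。需要注意:生成上述输出的代码在每轮开始时把 `discriminator.metric` 重置了两次,却没有重置 `generator.metric`,因此表中的 g_loss 是从第 0 轮起累计的平均值,而不是当轮的平均值,其变化趋势会被明显平滑。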
随后继续训练(第 7–26 轮),分类器和生成器的损失值都在下降,双方都在不断进步。
从第 27 轮开始,分类器损失继续下降而生成器损失上升,说明分类器已经明显领先于生成器。
|