概要
GAN(Generative Adversarial Network)生成对抗网络。其实要理解GAN的构想逻辑并不难,像其他的一些模型比如说最最基础的nn.Linear() + nn.ReLU() ,或者是RNN模型,我们不妨把这个模型看成一位武侠,他的目的是要跟江湖上尽可能多的人(data)过招(train),目的是在未来遇到邪恶的坏蛋(真实情景应用)时能够一招制敌(给出正确的结果)。 但是天不遂人愿,在茫茫的人海中,真正的武林高手有几个?又有几个能被我遇到?今天打过了丐帮的降龙十八掌,明天谁知道会不会被一记九阳神功拍的头昏眼花?(能接触到的数据总是有限的)武侠仰天沉思,他想起那一年去西域,自己仗着在中原打遍天下无敌手(过拟合)四处张扬得不行,结果被一旁的扫地大爷一套西洋拳术带走(模型不适应其他数据)。 可是家中有老母要照顾,忠孝难两全,武侠也因此一直呆在中原。由于放眼神州已无敌手,便打起了木人桩。机会总是留给有准备的人,有一天武侠捡到了阿拉丁神灯,神灯答应了他的愿望,点化了他的木人桩,让他能主动与武侠打斗并且不断增强自己的武力值,直击武侠痛点。武侠大喜,从此开始了与被点化的木人桩的切磋之路,技艺日增,终成一代地球大侠。
这个木人桩和武侠就是GAN中的Generator和Discriminator。对于Discriminator而言,它的目标是分辨出真假数据,对于Generator而言,它的目标是要制造出能以假乱真的数据。在学习的过程中,Generator的输入我们用torch.randn 产生随机数据,以此希望通过Generator产生各种各样的输入。 简单地说,二者的目标总结为:
二者在训练的过程中我们应该可以看到两边的loss大致是一个此消彼长的关系,这也是GAN中A(Adversarial)的本意,两者对抗。
一个例子 (base on MNIST)
我用在暑假跟着学深度学习中一课时的代码复现给大家分享一下。MNIST数据集是一个图片集,都是手写的单个数字,有images和labels两个部分。用torch.utils.data.Dataset 或者torchvision.Datasets.MNIST 可以读入为dataset 实例,进一步构造dataloader 。废话不多说,上代码。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
latent_size = 64
image_size = 28*28
hidden_size = 256
output_size = 10
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5), std=(0.5))
])
mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)
loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(),
nn.Linear(hidden_size, latent_size),
nn.LeakyReLU(),
nn.Linear(latent_size, 1),
nn.Sigmoid()
)
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.LeakyReLU(),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh()
)
D = D.to(device)
G = G.to(device)
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
loss_fn = nn.BCELoss()
在这里有一些不太一样的地方,一般来讲Discriminator的优化要显著快于Generator,所以我这里采用Generator优化3次,Discriminator优化1次的比例来进行optimize。
total_epoch = 200
total_iteration = len(loader)
d_loss = 0.
g_loss = 0.
for epoch in range(total_epoch):
for i, (img, label) in enumerate(loader):
img = img.to(device)
label = label.to(device)
batch_size = img.size(0)
img = img.reshape(batch_size, -1)
real_label = torch.ones(batch_size, 1).to(device)
if i % 3 == 0:
fake_label = torch.zeros(batch_size, 1).to(device)
output_real = D(img)
d_loss_real = loss_fn(output_real, real_label)
latent = torch.randn(batch_size, latent_size).to(device)
fake_img = G(latent)
output_fake = D(fake_img.detach())
d_loss_fake = loss_fn(output_fake, fake_label)
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
latent = torch.randn(batch_size, latent_size).to(device)
fake_img = G(latent)
output_fake = D(fake_img)
g_loss = loss_fn(output_fake, real_label)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if i % 300 == 0:
print("Epoch: [{}/{}], iteration: [{}/{}] d_loss: {:.8f}, d_loss_real: {:.8f}, d_loss_fake: {:.8f}, g_loss:{:.4f}"
.format(epoch, total_epoch, i, total_iteration, d_loss.item(), d_loss_real.item(), d_loss_fake.item(), g_loss.item()))
部分输出为:
其他数据源
我还用名人照片合集做了一个GAN,那个数据集我下载的时候叫做img_align_celeba。这个结果比较好,训练速度也相对较快,反正比图片风格迁移快多了(一张32*32的图片跑了我45分钟…),不过最终血小板style的寡姐还是不错的奥。 呸呸呸! 其他的不赘述,整体的思路与实现逻辑与mnist based的情况相差不多,只是在Generator, Discriminator两个nn.Sequential 的定义上复杂了一些(加了几个线性层)而已。 提前说明一下,我的GTX 1060没那么大显存,老师的服务器近期又在抽风上不去,所有的图片有24W张,根本放不进去,我就意思了一下放了1.2W,多训了几个epoch,不过最后的结果我相对满意(有人样了哈哈哈哈哈)。 文末: 有问题希望大家指出!有欠缺的地方希望大家能指出!😛
|