什么是GAN
GAN(Generative Adversarial Network,生成对抗网络)正如它的名字一样,既有"生成",也有"对抗",由两个相互博弈的网络组成。我们给这两个网络各起一个名字:第一个网络用来生成数据,称为生成器(generator);另一个网络用来鉴别生成器生成的数据,称为鉴别器(discriminator)。
GAN的训练
标准GAN的训练有三步:
- 用真实的训练数据训练鉴别器
- 用生成的数据训练鉴别器
- 训练生成器生成数据,并使鉴别器以为是真实数据
数据集
使用经典的 MNIST 数据集(代码中读取的是 CSV 文件 mnist_train.csv)。这个数据集网上很容易找到,这里就不再提供了。
代码
多数代码来自《PyTorch生成对抗网络编程》(人民邮电出版社)。书中有些写法我不太习惯,因此做了不少重构,最终效果相差不大。模式崩溃(mode collapse)等问题已经修复。
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data as Data
from sklearn.preprocessing import OneHotEncoder
import scipy.io as scio
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
# Load MNIST from CSV. Column 0 is used as the digit label and the remaining
# columns as the pixel values of each image (reshaped to 28x28 when plotted).
mnist_dataset = pd.read_csv('mnist_train.csv', header=None).values
label = mnist_dataset[:, 0]
# Dividing by 255 assumes raw 8-bit pixel intensities -- TODO confirm against
# the CSV; it puts the images on the same (0, 1) scale as the generator's
# sigmoid output.
image_values = mnist_dataset[:, 1:] / 255.0

# scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed the
# old keyword, so ``OneHotEncoder(sparse=False)`` fails on current releases.
# Try the new spelling first and fall back for older installations.
try:
    encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(sparse=False)
label = encoder.fit_transform(label.reshape(-1, 1))

# Convert to float32 tensors and wrap them in a shuffled loader.
train_t = torch.from_numpy(image_values.astype(np.float32))
label = torch.from_numpy(label.astype(np.float32))
train_data = Data.TensorDataset(train_t, label)
# batch_size=1 is deliberate: the training loop below feeds one image per
# step by indexing ``b_x[0]``.
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=1,
    shuffle=True,
)
def plot_num_image(index):
    """Display training image ``index`` in grayscale, titled with its label.

    Reads the module-level ``image_values`` and ``label``. ``label`` is
    one-hot encoded by the loading code, so the title shows the encoded
    vector rather than a single digit.
    """
    pixels = image_values[index].reshape(28, 28)
    plt.imshow(pixels, cmap='gray')
    caption = 'label=' + str(label[index])
    plt.title(caption)
    plt.show()
def generate_random(size):
    """Return a tensor of shape ``size`` drawn uniformly from [0, 1)."""
    return torch.rand(size)
def generate_random_seed(size):
    """Return a tensor of shape ``size`` drawn from a standard normal.

    Used as the latent seed fed to the generator.
    """
    return torch.randn(size)
class Discriminator(nn.Module):
    """Fully connected classifier that scores a 784-value image as real/fake.

    The network ends in a sigmoid, so ``forward`` returns a single value in
    (0, 1). The class owns its own BCE loss and Adam optimiser, and keeps a
    step counter plus a sampled loss history for plotting.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 300),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(300),
            nn.Linear(300, 30),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(30),
            nn.Linear(30, 1),
            nn.Sigmoid(),
        )
        self.loss_function = nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01)
        # Number of optimisation steps taken so far.
        self.counter = 0
        # Loss value recorded on every 10th step.
        self.progress = []

    def forward(self, inputs):
        """Return the sigmoid score produced by the network for ``inputs``."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Run one optimisation step, or switch the module's training mode.

        This method shadows ``nn.Module.train(mode)``. Without special
        handling, ``eval()`` (which calls ``self.train(False)``) and a bare
        ``train()`` would fail with a missing-argument ``TypeError``. When
        ``targets`` is omitted the call is therefore forwarded to
        ``nn.Module.train`` with ``inputs`` as the boolean mode.

        Args:
            inputs: Sample to classify, or the boolean mode when ``targets``
                is omitted.
            targets: Expected discriminator output for ``inputs``.

        Returns:
            ``self`` for a mode switch (as ``nn.Module.train`` does), ``None``
            after an optimisation step.
        """
        if targets is None:
            return super().train(inputs)

        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
class Generator(nn.Module):
    """Fully connected network mapping a 100-value seed to 784 values.

    The final sigmoid keeps every output in (0, 1). The class owns its Adam
    optimiser but borrows the loss function from the discriminator it is
    trained against.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 300),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(300),
            nn.Linear(300, 784),
            nn.Sigmoid(),
        )
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01)
        # Number of optimisation steps taken so far.
        self.counter = 0
        # Loss value recorded on every 10th step.
        self.progress = []

    def forward(self, inputs):
        """Return the generated 784-value output for the seed ``inputs``."""
        return self.model(inputs)

    def train(self, D=True, inputs=None, targets=None):
        """Run one generator step against ``D``, or switch the training mode.

        This method shadows ``nn.Module.train(mode)``. Without special
        handling, ``eval()`` (which calls ``self.train(False)``) and a bare
        ``train()`` would fail with a missing-argument ``TypeError``. When
        both ``inputs`` and ``targets`` are omitted the call is therefore
        forwarded to ``nn.Module.train`` with ``D`` as the boolean mode.

        Args:
            D: Discriminator providing ``forward`` and ``loss_function``, or
                the boolean mode when the other arguments are omitted.
            inputs: Seed passed through the generator.
            targets: Output the discriminator should produce for the sample.

        Returns:
            ``self`` for a mode switch (as ``nn.Module.train`` does), ``None``
            after an optimisation step.
        """
        if inputs is None and targets is None:
            return super().train(D)

        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # Only the generator's optimiser steps here, so D's weights are not
        # updated even though the loss is computed through D.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
# Instantiate the two adversaries: D scores samples, G produces them.
D = Discriminator()
G = Generator()
# The triple-quoted string below is an inert expression statement: a disabled
# earlier experiment, kept for reference, that trains D alone on real images
# versus uniform noise, plots its loss, and then displays one sample from the
# untrained G.
# NOTE(review): its `G.forward(generate_random(1))` call does not match the
# 100-input first layer of Generator and would need updating before this
# experiment is re-enabled.
'''
for step, (b_x, b_y) in enumerate(train_loader):
# 真实数据
D.train(b_x[0], torch.FloatTensor([1.0]))
# 生成数据
D.train(generate_random(784), torch.FloatTensor([0.0]))
plt.plot(D.progress) # loss很快就归0了
plt.show()
# 输出一个真是数据和生成数据
print('real_num:', D.forward(b_x[0]).item())
print('generate-num:', D.forward(generate_random(784)).item())
# 至此我们的鉴别器已经学会分类真实数据和我们随机生成的数据了
# 让生成器随机产生一个图像我们看看
output = G.forward(generate_random(1))
img = output.detach().numpy().reshape(28, 28)
plt.imshow(img, interpolation='none', cmap='gray') # interpolation 差值方法
plt.show()
'''
# Adversarial training. For every real image: D takes one step on the real
# sample, one step on a detached generated sample, and then G takes one step
# trying to make D output "real" for a fresh sample.
real_target = torch.FloatTensor([1.0])
fake_target = torch.FloatTensor([0.0])
for epoch in range(10):
    for step, (b_x, b_y) in enumerate(train_loader):
        # The loader yields batches of one image; index 0 selects that image.
        real_image = b_x[0]
        D.train(real_image, real_target)

        # detach() stops this discriminator step from back-propagating into G.
        fake_image = G.forward(generate_random_seed(100)).detach()
        D.train(fake_image, fake_target)

        G.train(D, generate_random_seed(100), real_target)
    print('完成',epoch+1,'epoch','*************'*3)

# Plot both sampled loss histories on one figure, save it, then display it.
for history, colour, legend_name in (
    (D.progress, 'b', 'D-loss'),
    (G.progress, 'r', 'G-loss'),
):
    plt.plot(history, c=colour, label=legend_name)
plt.legend()
plt.savefig('loss.jpg')
plt.show()
# Draw six freshly generated samples in a 2x3 grid of grayscale images.
for i in range(6):
    output = G.forward(generate_random_seed(100))
    # detach() drops the autograd graph before handing the data to NumPy.
    img = output.detach().view(28, 28).numpy()
    plt.subplot(2, 3, i + 1)
    plt.imshow(img, cmap='gray')
plt.show()
我们生成几张图像看看:
# Generate six samples from random seeds and show them as a 2x3 image grid.
for i in range(6):
    output = G.forward(generate_random_seed(100))
    # Detach from autograd, convert to NumPy, and reshape to 28x28 pixels.
    img = np.reshape(output.detach().numpy(), (28, 28))
    plt.subplot(2, 3, i + 1)
    plt.imshow(img, cmap='gray')
plt.show()
生成的图像看着很像 000038,效果已经不错了。生成器从未直接见过真实数字长什么样子,但它学会了怎么写出(生成)相似的图像。我刚开始学 GAN 不久,目前生成器只能随机生成数字图像,还无法生成指定的数字,暂时还没想到怎么解决。
|