在网上找了一个 WGAN 的实现代码,在本地跑了一下,效果还可以。我把它封装成了一个函数,感兴趣的朋友可以用一下。
不过这个 GAN 生成的是一维数据,对于图片数据可能需要对代码进行一些修改。
import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import shuffle

torch.manual_seed(1)
warnings.filterwarnings("ignore")
def train_model_save_gen(data, ITERS=600, iter_ctrl=200, use_cuda=False,
                         name_save='', file_name='./save_gen/'):
    """Train a WGAN with gradient penalty on 2-D data and checkpoint it.

    The whole data set is used as a single batch. Every ``iter_ctrl``
    iterations a 1-nearest-neighbour classifier is cross-validated on
    "real vs. generated" rows; an accuracy close to 0.5 means the classifier
    cannot tell the two apart. At each such checkpoint the generated rows,
    both networks' weights and the loss histories are written to disk.

    Args:
        data: 2-D array-like of shape (n_samples, n_features). Anything
            ``np.asarray`` accepts (ndarray, DataFrame, nested list).
        ITERS: Number of outer training iterations.
        iter_ctrl: Evaluate and save every ``iter_ctrl`` iterations.
        use_cuda: Run the networks on the default CUDA device when True.
        name_save: Tag inserted into the generated-sample CSV file names.
        file_name: Directory that receives the CSV files (created if absent).
            Network weights are always written to ``./model_file/``.

    Returns:
        Tuple ``(best_item, opt_diff_accuracy_05)``: the iteration whose
        checkpoint accuracy was closest to 0.5 and that absolute distance.
        ``(0, 0.5)`` if no checkpoint improved on the initial distance of 0.5.
    """
    FIXED_GENERATOR = False  # when True the generator just returns noise + real data
    LAMBDA = .1              # weight of the gradient penalty term
    CRITIC_ITERS = 5         # critic updates per outer iteration
    CRITIC_ITERG = 1         # generator updates per outer iteration
    MODEL_DIR = './model_file/'

    data = np.asarray(data)
    BATCH_SIZE = len(data)
    n_features = data.shape[1]
    device = torch.device('cuda' if use_cuda else 'cpu')

    # makedirs (not mkdir) so that nested output paths work and an existing
    # directory is not an error.
    os.makedirs(file_name, exist_ok=True)

    class Generator(nn.Module):
        """MLP mapping a noise row to a generated row of the same width."""

        def __init__(self, shape1):
            super().__init__()
            self.main = nn.Sequential(
                nn.Linear(shape1, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 512),
                nn.ReLU(True),
                nn.Linear(512, 256),
                nn.ReLU(True),
                nn.Linear(256, 512),
                nn.ReLU(True),
                nn.Linear(512, 1024),
                nn.Tanh(),
                nn.Linear(1024, shape1),
            )

        def forward(self, noise, real_data):
            if FIXED_GENERATOR:
                return noise + real_data
            return self.main(noise)

    class Discriminator(nn.Module):
        """MLP critic producing one unbounded score per input row."""

        def __init__(self, shape1):
            super().__init__()
            # Attribute names are kept so saved state_dict keys stay stable.
            self.fc1 = nn.Linear(shape1, 512)
            self.relu1 = nn.LeakyReLU(0.2)
            self.fc2 = nn.Linear(512, 256)
            self.relu2 = nn.LeakyReLU(0.2)
            self.fc3 = nn.Linear(256, 256)
            self.relu3 = nn.LeakyReLU(0.2)
            self.fc4 = nn.Linear(256, 128)
            self.relu4 = nn.LeakyReLU(0.2)
            self.fc5 = nn.Linear(128, 1)

        def forward(self, inputs):
            out = self.relu1(self.fc1(inputs))
            out = self.relu2(self.fc2(out))
            out = self.relu3(self.fc3(out))
            out = self.relu4(self.fc4(out))
            return self.fc5(out).view(-1)

    def weights_init(m):
        """Initialise Linear / BatchNorm layers with N(., 0.02) weights, zero bias."""
        classname = m.__class__.__name__
        if 'Linear' in classname:
            nn.init.normal_(m.weight, 0.0, 0.02)
            nn.init.zeros_(m.bias)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

    def calc_gradient_penalty(netD, real_data, fake_data):
        """Return LAMBDA * mean((||grad D(x_hat)||_2 - 1)^2) on random interpolates."""
        # One mixing coefficient per row, broadcast across the feature axis.
        alpha = torch.rand(real_data.size(0), 1).to(device)
        interpolates = alpha * real_data + (1 - alpha) * fake_data
        interpolates = interpolates.detach().requires_grad_(True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,  # the penalty itself must be differentiable
            retain_graph=True,
        )[0]
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    def sample_noise():
        """Draw one batch of standard-normal noise and move it to the device."""
        return torch.randn(BATCH_SIZE, n_features).to(device)

    netG = Generator(n_features)
    netD = Discriminator(n_features)
    netD.apply(weights_init)
    netG.apply(weights_init)
    netD.to(device)
    netG.to(device)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    # Labels for the real-vs-generated classifier: 1 = real row, 0 = generated.
    one_list = np.ones(data.shape[0])
    zero_list = np.zeros(data.shape[0])

    opt_diff_accuracy_05 = 0.5
    best_item = 0
    loss_list = {'D_loss': [], 'G_loss': []}

    for iteration in range(ITERS):
        sys.stdout.write(f'\r进行:{iteration}/{ITERS}')
        sys.stdout.flush()

        # The critic is frozen during the generator step below; unfreeze it.
        for p in netD.parameters():
            p.requires_grad_(True)

        # Rebuilt every iteration because `data` may be reshuffled below.
        real_data = torch.tensor(data, dtype=torch.float32).to(device)

        # Snapshot of generated rows used by the checkpoint evaluation.
        # NOTE: it is sampled *before* this iteration's updates, so it
        # reflects the generator as it was at the start of the iteration.
        with torch.no_grad():
            fake_output = netG(sample_noise(), real_data).cpu().numpy()

        # ---- critic updates ----
        for _ in range(CRITIC_ITERS):
            netD.zero_grad()
            # The generator is not trained here, so skip building its graph.
            with torch.no_grad():
                fake = netG(sample_noise(), real_data)
            D_real = netD(real_data).mean()
            D_fake = netD(fake).mean()
            gradient_penalty = calc_gradient_penalty(netD, real_data, fake)
            D_cost = D_fake - D_real + gradient_penalty
            D_cost.backward()
            loss_list['D_loss'].append(D_cost.item())
            optimizerD.step()

        # ---- generator updates ----
        if not FIXED_GENERATOR:
            # Freeze the critic so no gradients are accumulated for it.
            for p in netD.parameters():
                p.requires_grad_(False)
            for _ in range(CRITIC_ITERG):
                netG.zero_grad()
                fake = netG(sample_noise(), real_data)
                G_cost = -netD(fake).mean()
                G_cost.backward()
                loss_list['G_loss'].append(G_cost.item())
                optimizerG.step()

        # ---- periodic evaluation and checkpoint ----
        if iteration % iter_ctrl == 0:
            if iteration % 10000 == 0:
                data = shuffle(data)
            print()
            print(f'循环{iteration}次..')

            x = np.concatenate((data, fake_output), axis=0)
            y = np.concatenate((one_list, zero_list), axis=0)
            kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            pred_label = np.zeros(x.shape[0])
            for train_index, test_index in kfold.split(x, y):
                knn = KNeighborsClassifier(n_neighbors=1).fit(
                    x[train_index], y[train_index])
                pred_label[test_index] = knn.predict(x[test_index])
            accuracy = accuracy_score(y, pred_label)
            print(f'计算{iteration}的acc={accuracy}')

            diff_accuracy_05 = abs(accuracy - 0.5)
            if diff_accuracy_05 < opt_diff_accuracy_05:
                opt_diff_accuracy_05 = diff_accuracy_05
                best_item = iteration

            # Saved at every checkpoint so the files for `best_item` always exist.
            pd.DataFrame(fake_output).to_csv(
                os.path.join(file_name,
                             'Iteration3_' + str(name_save) + '_' + str(iteration) + '.csv'),
                index=False)

            os.makedirs(MODEL_DIR, exist_ok=True)
            torch.save(netG.state_dict(), MODEL_DIR + 'netG' + str(iteration) + '.dict')
            torch.save(netD.state_dict(), MODEL_DIR + 'netD' + str(iteration) + '.dict')

            pd.DataFrame(loss_list['G_loss']).to_csv(
                os.path.join(file_name, 'Gloss_' + str(iteration) + '.csv'), index=False)
            pd.DataFrame(loss_list['D_loss']).to_csv(
                os.path.join(file_name, 'Dloss_' + str(iteration) + '.csv'), index=False)

    return best_item, opt_diff_accuracy_05
调用上述函数即可,例如:`best_item, diff = train_model_save_gen(data)`,其中 `data` 为二维数组(样本数 × 特征数)。