IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch实现 wgan -> 正文阅读

[人工智能]pytorch实现 wgan

作者:recommend-item-box-tow

在网上找了一个wgan的实现代码,在本地跑了以下,效果还可以,我把它封装成一个函数了,感兴趣的朋友可以用一下

不过这个gan生成的是一维数据,对于图片数据可能需要对代码进行一些改变

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(1)
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings("ignore")

def train_model_save_gen(data, ITERS = 600, iter_ctrl=200, use_cuda = False, name_save='', file_name='./save_gen/'):

    if not os.path.exists(file_name):
        os.mkdir(file_name)
    FIXED_GENERATOR = False
    LAMBDA = .1
    CRITIC_ITERS = 5
    CRITIC_ITERG = 1
    BATCH_SIZE = len(data)

    class Generator(nn.Module):

        def __init__(self, shape1):
            super(Generator, self).__init__()

            main = nn.Sequential(
                nn.Linear(shape1, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 512),
                nn.ReLU(True),
                nn.Linear(512, 256),
                nn.ReLU(True),
                nn.Linear(256, 512),
                nn.ReLU(True),
                nn.Linear(512, 1024),
                nn.Tanh(),
                nn.Linear(1024, shape1),
            )
            self.main = main

        def forward(self, noise, real_data):
            if FIXED_GENERATOR:
                return noise + real_data
            else:
                output = self.main(noise)
                return output

    class Discriminator(nn.Module):

        def __init__(self, shape1):
            super(Discriminator, self).__init__()

            self.fc1 = nn.Linear(shape1, 512)
            self.relu1 = nn.LeakyReLU(0.2)
            self.fc2 = nn.Linear(512, 256)
            self.relu2 = nn.LeakyReLU(0.2)
            self.fc3 = nn.Linear(256, 256)
            self.relu3 = nn.LeakyReLU(0.2)
            self.fc4 = nn.Linear(256, 128)
            self.relu4 = nn.LeakyReLU(0.2)
            self.fc5 = nn.Linear(128, 1)

        def forward(self, inputs):
            out = self.fc1(inputs)
            out = self.relu1(out)
            out = self.fc2(out)
            out = self.relu2(out)
            out = self.fc3(out)
            out = self.relu3(out)
            out = self.fc4(out)
            out = self.relu4(out)
            out = self.fc5(out)
            return out.view(-1)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def calc_gradient_penalty(netD, real_data, fake_data):
        alpha = torch.rand(BATCH_SIZE, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda() if use_cuda else alpha
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        if use_cuda:
            interpolates = interpolates.cuda()
        interpolates = autograd.Variable(interpolates, requires_grad=True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(
                                      disc_interpolates.size()), create_graph=True, retain_graph=True,
                                  only_inputs=True)[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
        return gradient_penalty

    netG = Generator(data.shape[1])
    netD = Discriminator(data.shape[1])
    netD.apply(weights_init)
    netG.apply(weights_init)
    if use_cuda:
        netD = netD.cuda()
        netG = netG.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    one = torch.tensor(1, dtype=torch.float)  ###torch.FloatTensor([1])
    mone = one * -1
    if use_cuda:
        one = one.cuda()
        mone = mone.cuda()

    ### ### ###
    one_list = np.ones((data.shape[0]))
    zero_list = np.zeros((data.shape[0]))
    opt_diff_accuracy_05 = 0.5
    best_item = 0
    opt_accuracy = 0
    all_result = []

    loss_list = {'D_loss':[], 'G_loss':[]}
    for iteration in range(ITERS):
        sys.stdout.write(f'\r进行:{iteration}/{ITERS}')  # \r 默认表示将输出的内容返回到第一个指针,这样的话,后面的内容会覆盖前面的内容。
        sys.stdout.flush()
        for p in netD.parameters():
            p.requires_grad = True
        # data = inf_train_gen('data_GAN')
        real_data = torch.FloatTensor(data)
        if use_cuda:
            real_data = real_data.cuda()
            false_data = false_data.cuda()
        real_data_v = autograd.Variable(real_data)
        false_data_v = autograd.Variable(false_data)

        noise = torch.randn(BATCH_SIZE, data.shape[1])
        if use_cuda:
            noise = noise.cuda()

        noisev = autograd.Variable(noise, volatile=True)
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        fake_output = fake.data.cpu().numpy()
        for iter_d in range(CRITIC_ITERS):
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)  ##############
            noise = torch.randn(BATCH_SIZE, data.shape[1])
            if use_cuda:
                noise = noise.cuda()
            noisev = autograd.Variable(noise, volatile=True)  # volatile=True相当于 requires_grad=False
            fake = autograd.Variable(netG(noisev, real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)  ################

            gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
            gradient_penalty.backward()  ############
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            loss_list['D_loss'].append(D_cost.item())
            optimizerD.step()

        if not FIXED_GENERATOR:
            for p in netD.parameters():
                p.requires_grad = False
            for iter_g in range(CRITIC_ITERG):
                netG.zero_grad()
                real_data = torch.Tensor(data)
                if use_cuda:
                    real_data = real_data.cuda()

                real_data_v = autograd.Variable(real_data)
                noise = torch.randn(BATCH_SIZE, data.shape[1])
                if use_cuda:
                    noise = noise.cuda()

                noisev = autograd.Variable(noise)
                fake = netG(noisev, real_data_v)
                G = netD(fake)
                G = G.mean()
                G.backward(mone)
                G_cost = -G
                loss_list['G_loss'].append(G_cost.item())
                optimizerG.step()

        ###save generated sample features every 200 iteration
        if iteration % iter_ctrl == 0:
            if iteration % 10000 == 0:
                data = shuffle(data)
            # save_temp = pd.DataFrame(fake_output)
            # # fake_writer = open(file_name + "/Iteration_" + str(iteration) + ".txt", "w")
            # save_temp.to_csv(file_name + "/Iteration_" + str(iteration) + ".csv", index=None)

            print()
            print(f'循环{iteration}次..')

            x = np.concatenate((data, fake_output), axis=0)
            y = np.concatenate((one_list, zero_list), axis=0)
            kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            real_label = np.zeros((x.shape[0]))
            pred_label = np.zeros((x.shape[0]))

            for train_index, test_index in kfold.split(x, y):
                x_train, x_test = x[train_index], x[test_index]
                y_train, y_test = y[train_index], y[test_index]
                knn = KNeighborsClassifier(n_neighbors=1).fit(x_train, y_train)
                predicted_y = knn.predict(x_test)
                pred_label[test_index] = predicted_y
                real_label[test_index] = y_test
            accuracy = accuracy_score(real_label, pred_label)
            all_result.append(str(iteration) + "," + str(accuracy))
            print(f'计算{iteration}的acc={accuracy}')

            diff_accuracy_05 = abs(accuracy - 0.5)
            if diff_accuracy_05 < opt_diff_accuracy_05:
                opt_diff_accuracy_05 = diff_accuracy_05
                best_item = iteration
                opt_accuracy = accuracy
                save_temp = pd.DataFrame(fake_output)
                # fake_writer = open(file_name + "/Iteration_" + str(iteration) + ".txt", "w")
                save_temp.to_csv(file_name + "/Iteration3_"+ str(name_save) +'_'+ str(iteration) + ".csv",index=None)

            torch.save(netG.state_dict(), './model_file/netG'+str(iteration)+'.dict')
            torch.save(netD.state_dict(), './model_file/netD'+str(iteration)+'.dict')

    save_loss = pd.DataFrame(loss_list['G_loss'])
    save_loss.to_csv(file_name + "/Gloss_" + str(iteration) + '.csv', index=None)
    save_loss = pd.DataFrame(loss_list['D_loss'])
    save_loss.to_csv(file_name + "/Dloss_" + str(iteration) + '.csv', index=None)

    return best_item,opt_diff_accuracy_05

调用上述函数即可

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-07 22:41:41  更:2022-04-07 22:43:49 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:19:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码