IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GAN简介与复习 基于pytorch -> 正文阅读

[人工智能]GAN简介与复习 基于pytorch

概要

GAN(Generative Adversarial Network)生成对抗网络。其实要理解GAN的构想逻辑并不难,像其他的一些模型比如说最最基础的nn.Linear() + nn.ReLU(),或者是RNN模型,我们不妨把这个模型看成一位武侠,他的目的是要跟江湖上尽可能多的人(data)过招(train),目的是在未来遇到邪恶的坏蛋(真实情景应用)时能够一招制敌(给出正确的结果)。
但是天不遂人愿,在茫茫的人海中,真正的武林高手有几个?又有几个能被我遇到?今天打过了丐帮的降龙十八掌,明天谁知道会不会被一记九阳神功拍的头昏眼花?(能接触到的数据总是有限的)武侠仰天沉思,他想起那一年去西域,自己仗着在中原打遍天下无敌手(过拟合)四处张扬得不行,结果被一旁的扫地大爷一套西洋拳术带走(模型不适应其他数据)。
可是家中有老母要照顾,忠孝难两全,武侠也因此一直呆在中原。由于放眼神州已无敌手,便打起了木人桩。机会总是留给有准备的人,有一天武侠捡到了阿拉丁神灯,神灯答应了他的愿望,点化了他的木人桩,让他能主动与武侠打斗并且不断增强自己的武力值,直击武侠痛点。武侠大喜,从此开始了与被点化的木人桩的切磋之路,技艺日增,终成一代地球大侠。

这个木人桩和武侠就是GAN中的Generator和Discriminator。对于Discriminator而言,它的目标是分辨出真假数据,对于Generator而言,它的目标是要制造出能以假乱真的数据。在学习的过程中,Generator的输入我们用torch.randn产生随机数据,以此希望通过Generator产生各种各样的输入。
简单地说,二者的目标总结为:

  • Discriminator: 给定数据x,我希望分辨出这是真实产生的数据,还是Generator模拟的假数据,输出0-1

  • Generator: 给定随机数random,我希望能蒙混过关,尽可能模拟真实数据 output.shape == x.shape

二者在训练的过程中我们应该可以看到两边的loss大致是一个此消彼长的关系,这也是GAN中A(Adversarial)的本意,两者对抗。

一个例子 (base on MNIST)

我用在暑假跟着学深度学习中一课时的代码复现给大家分享一下。MNIST数据集是一个图片集,都是手写的单个数字,有images和labels两个部分。用torch.utils.data.Dataset或者torchvision.Datasets.MNIST可以读入为dataset实例,进一步构造dataloader。废话不多说,上代码。

# 一些经常用的库和函数
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# 定义一些超参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
latent_size = 64	# 就是上文中生成的random的长度
image_size = 28*28  # 这是MNIST数据集中图片的大小
hidden_size = 256  # 定义Discriminator和Generator模型中的隐层的大小
output_size = 10  # 最终输出为10维的向量,代表1~10每位数字,需要的话也可以用softmax转换为概率
transform = transforms.Compose([
    transforms.ToTensor(),   # 将读进来的图片转为Tensor
    transforms.Normalize(mean=(0.5), std=(0.5))		# 将图片每个像素中心化  (这两个值没有求证)
])
mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)
loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
# Discriminator 的定义
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(),
    nn.Linear(hidden_size, latent_size),
    nn.LeakyReLU(),
    nn.Linear(latent_size, 1),
    nn.Sigmoid()
)

# Generator 的定义
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.LeakyReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)
# 以上两个模型并不很长,如果模型层数较多,建议加上BatchNorm层

D = D.to(device)
G = G.to(device)
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
loss_fn = nn.BCELoss()  # 之前的output_size=10,注意target的size

在这里有一些不太一样的地方,一般来讲Discriminator的优化要显著快于Generator,所以我这里采用Generator优化3次,Discriminator优化1次的比例来进行optimize。

total_epoch = 200
total_iteration = len(loader)
d_loss = 0.
g_loss = 0.

# optimize ratio
# Discriminator : Generator = 1 : 3

for epoch in range(total_epoch):
    for i, (img, label) in enumerate(loader):
        # Discriminator
        # batch * 1 * 28 * 28
        img = img.to(device)
        label = label.to(device)

        batch_size = img.size(0)
        img = img.reshape(batch_size, -1)

        real_label = torch.ones(batch_size, 1).to(device)
        if i % 3 == 0:
            fake_label = torch.zeros(batch_size, 1).to(device)
            output_real = D(img)
            d_loss_real = loss_fn(output_real, real_label)

            latent = torch.randn(batch_size, latent_size).to(device)
            fake_img = G(latent)
            output_fake = D(fake_img.detach())
            d_loss_fake = loss_fn(output_fake, fake_label)  # 注意与下文中Generator的target做对比

            d_loss = d_loss_real + d_loss_fake

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # Generator
        latent = torch.randn(batch_size, latent_size).to(device)
        fake_img = G(latent)
        output_fake = D(fake_img)
        g_loss = loss_fn(output_fake, real_label)  # 注意对于Generator,target应该是什么

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()


        if i % 300 == 0:
            print("Epoch: [{}/{}], iteration: [{}/{}] d_loss: {:.8f}, d_loss_real: {:.8f}, d_loss_fake: {:.8f}, g_loss:{:.4f}"
            .format(epoch, total_epoch, i, total_iteration, d_loss.item(), d_loss_real.item(), d_loss_fake.item(), g_loss.item()))

部分输出为:
在这里插入图片描述

其他数据源

我还用名人照片合集做了一个GAN,那个数据集我下载的时候叫做img_align_celeba。这个结果比较好,训练速度也相对较快,反正比图片风格迁移快多了(一张32*32的图片跑了我45分钟…),不过最终血小板style的寡姐还是不错的奥。
呸呸呸!
其他的不赘述,整体的思路与实现逻辑与mnist based的情况相差不多,只是在Generator, Discriminator两个nn.Sequential的定义上复杂了一些(加了几个线性层)而已。
提前说明一下,我的GTX 1060没那么大显存,老师的服务器近期又在抽风上不去,所有的图片有24W张,根本放不进去,我就意思了一下放了1.2W,多训了几个epoch,不过最后的结果我相对满意(有人样了哈哈哈哈哈)。
在这里插入图片描述
文末:
有问题希望大家指出!有欠缺的地方希望大家能指出!😛

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-28 22:00:04  更:2021-08-28 22:00:17 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 22:46:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码