IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 什么是对抗网络,对抗网络能干什么,对抗网络简述。 -> 正文阅读

[人工智能]什么是对抗网络,对抗网络能干什么,对抗网络简述。

一、什么是对抗网络:

生成式对抗网络(Generative adversarial network, GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。


二、对抗网络能干什么:

(1)数据生成,主要指图像生成。图像生成:基于训练的模型,生成类似于训练集的新的图片。

(2)图像数据增强:增强图像中的有用信息,改善图像的视觉效果。

(3)图像外修复:从受限输入图像生成具有语义意义结构的新的视觉和谐内容来扩展图像边界。

(4)图像超分辨率:由一幅低分辨率图像或图像序列恢复出高分辨率图像。

(5)图像风格迁移:通过某种方法,把图像从原风格转换到另一个风格,同时保证图像内容没有变化。

(6)语音合成

意义:GAN网络可以帮助我们建立模型,相比于在已有模型上进行参数更新的传统网络,更具研究价值。


三、对抗网络由哪些部分组成:

(1)生成器(Generator):生成器要不断优化自己生成的数据让判别器判断不出来。

(2)判别器(Discriminator):判别器要进行优化让自己判断更准确

二者关系形成对抗因此叫生成式对抗网络。


接下来我简述下,对抗网络的过程是怎么走的,这是重点:

先给大家说下什么是BCE_LOSS(二元交叉熵):

他是一个专注与做二分类任务的损失函数,目的是求损失,梯度更新,在这里,里面weight(权重参数)不用写。建议大家去搜下这损失函数。

第一步:

先生成一组标签分别是0和1,稍后用作BCE_LOSS损失的输入。

第二步:

训练判别器

会先把真实数据送入判别模型,会返回一个值,然后我们把这个值,和真实值打的标签1求BCE_LOSS损失。

然后把假的的数据(噪音)送入生成模型,也会返回一个值,我们再把这个值,和假的标签0求BECE_LOSS损失。

最后把真实值损失假数据的损失加到一起,一起求梯度,进行更新

第三步:

训练生成器

因为我们在判别阶段已经更新了生成器的参数所以可以直接再次更新(其实就是参数共享)。

最后:可以根据对抗效果设置迭代次数。

可以参考图片理解,如果不行可以翻译看代码

我下面有一个用对抗网络生成图片的代码,大家可以参考参考:

可以直接复制到pycharm,需要改下手写体数据路径:

?

import os
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 优选超参数

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
bathc_size = 100
sample_dir = 'samples'

# 如果不存在目录创建一个目录
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# 图像处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
])

# 加载手写体数据集
mnist = torchvision.datasets.MNIST(
    root=r'C:\Users\qiuhongsen\PycharmProjects\pythonProject\NLP--2\My self dai\tensorflow1\MNIST_data',
    train=True,
    transform=transform,
    download=True)

# 数据加载器
data_loader = DataLoader(dataset=mnist,
                         batch_size=bathc_size,
                         shuffle=True)

# 判别器
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid()
)
# 生成器
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)

# 二分类交叉熵损失函数
loss = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

# 优化器初始化
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


# 开始训练
total_step = len(data_loader)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(bathc_size, -1).to(device)
        # 创建标签,稍后用作BCE丢失的输入
        real_labels = torch.ones(bathc_size, 1).to(device)
        fake_labels = torch.zeros(bathc_size, 1).to(device)
        # ================================================================== #
        #                      训练判断器                   #
        # ================================================================== #

        # 使用真实计算图计算BCE_LOSS
        # 损失的第二项总是1因为真实的标签是1
        oupouts = D(images)
        d_loss_real = loss(oupouts, real_labels)
        real_score = oupouts
        # 使用假的计算图计算BCE_LOSS
        # 损失的第一项总是0,因为假的标签是0
        z = torch.randn(bathc_size, latent_size).to(device)
        fake_images = G(z)
        oupout = D(fake_images)
        d_loss_fake = loss(oupout, fake_labels)
        fake_score = oupout

        # backprop 和 optimizer
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        # ================================================================== #
        #                        训练生成器                        #
        # ================================================================== #

        # 用假图计算损失
        z = torch.randn(bathc_size, latent_size).to(device)
        fake_images = G(z)
        oupouts = D(fake_images)

        g_loss = loss(oupout, real_labels)
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print('Epoch[{}/{}],Step[{}/{}],d_loss{:.4f},g_loss{:.4f},D(x):{.2f},D(G(z)):{.2f}'
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(),
                          fake_score.mean().item()))

这是真实的图片↑:

※大家有兴趣可以跑跑试试,迭代200次。

下面是迭代50词的图片↓:

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-08 20:42:06  更:2022-10-08 20:45:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/25 20:45:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码