IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GAN-生成对抗网络-生成手写数字(基于pytorch) -> 正文阅读

[人工智能]GAN-生成对抗网络-生成手写数字(基于pytorch)

什么是GAN

GAN(Generative Adversarial Network),网络也如他的名字一样,有生成,有对抗,两个网络相互博弈。我们给两个网络起个名字,第一个网络用来生成数据命名为生成器(generator),另一个网络用来鉴别生成器生成的数据我们命名为鉴别器(discriminator)

GAN的训练

标准GAN的训练有三步

  • 用真实的训练数据训练鉴别器
  • 用生成的数据训练鉴别器
  • 训练生成器生成数据,并使鉴别器以为是真实数据

数据集

经典mnist数据集,典中典了,不放了,网上很多。

代码

多数代码来自《Pytorch生成对抗网络编程》人民邮电出版社
有些书上的方法我不是很习惯,也重构了很多,最后效果都差不多。
已修复模式崩坏等问题

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data as Data
from sklearn.preprocessing import OneHotEncoder
import scipy.io as scio
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt

mnist_dataset = pd.read_csv('mnist_train.csv', header=None).values
label = mnist_dataset[:, 0]
image_values = mnist_dataset[:, 1:] / 255.0

encoder = OneHotEncoder(sparse=False)  # sparse默认为True,返回稀疏矩阵
label = encoder.fit_transform(label.reshape(-1, 1))

train_t = torch.from_numpy(image_values.astype(np.float32))
label = torch.from_numpy(label.astype(np.float32))

train_data = Data.TensorDataset(train_t, label)

train_loader = Data.DataLoader(dataset=train_data,
                               batch_size=1,
                               shuffle=True)


def plot_num_image(index):
    plt.imshow(image_values[index].reshape(28, 28), cmap='gray')
    plt.title('label=' + str(label[index]))
    plt.show()


def generate_random(size):
    random_data = torch.rand(size)
    return random_data


def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data


# 构建分类器
class Discriminator(nn.Module):
    def __init__(self):
        # 初始化父类
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 300),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(300),

            nn.Linear(300, 30),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(30),

            nn.Linear(30, 1),
            nn.Sigmoid(),

        )

        self.loss_function = nn.BCELoss()

        # 创建优化器
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01)

        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, inputs, targets):
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        # 每训练10此增加计数器
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # 清楚梯度,反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()


# 构建生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(100, 300),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(300),

            nn.Linear(300, 784),
            nn.Sigmoid(),


        )

        # 创建优化器
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01)

        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, D, inputs, targets):  # 用分类器的损失来训练生成
        g_output = self.forward(inputs)  # 生成器generator的输出

        d_output = D.forward(g_output)  # 分类器discriminator的输出

        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()


D = Discriminator()
G = Generator()

'''
for step, (b_x, b_y) in enumerate(train_loader):
    # 真实数据
    D.train(b_x[0], torch.FloatTensor([1.0]))

    # 生成数据
    D.train(generate_random(784), torch.FloatTensor([0.0]))

plt.plot(D.progress)  # loss很快就归0了
plt.show()

# 输出一个真是数据和生成数据
print('real_num:', D.forward(b_x[0]).item())
print('generate-num:', D.forward(generate_random(784)).item())
# 至此我们的鉴别器已经学会分类真实数据和我们随机生成的数据了


# 让生成器随机产生一个图像我们看看
output = G.forward(generate_random(1))
img = output.detach().numpy().reshape(28, 28)
plt.imshow(img, interpolation='none', cmap='gray')  # interpolation 差值方法
plt.show()
'''
for epoch in range(10):
    for step, (b_x, b_y) in enumerate(train_loader):
        # 真实数据
        D.train(b_x[0], torch.FloatTensor([1.0]))

        D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))

        G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))
    print('完成',epoch+1,'epoch','*************'*3)

# 我们看一下生成器和鉴别器的loss
plt.plot(D.progress, c='b', label='D-loss')
plt.plot(G.progress, c='r', label='G-loss')
plt.legend()
plt.savefig('loss.jpg')
plt.show()


# 此时的生成器已经经过训练,我们多生成几张看看
for i in range(6):
    output = G.forward(generate_random_seed(100))
    img = output.detach().numpy().reshape(28, 28)
    plt.subplot(2, 3, i+1)
    plt.imshow(img,cmap='gray')
plt.show()

我们生成几张图像看看:

for i in range(6):
    output = G.forward(generate_random_seed(100))
    img = output.detach().numpy().reshape(28, 28)
    plt.subplot(2, 3, i+1)
    plt.imshow(img, cmap='gray')
plt.show()

在这里插入图片描述
看着很像000038,非常好了,生成器并没有见过数字长什么样子,但是他学会了怎么写(生成)相似的图像。刚开始学GAN不久,至此我们的生成器也只是能随机生成图像,无法生成特定的数字。 还没想到怎么解决。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-13 21:47:25  更:2022-03-13 21:49:07 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 16:05:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码