IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PyTorch学习笔记:AutoEncoder自编码模型(基于Linear) -> 正文阅读

[人工智能]PyTorch学习笔记:AutoEncoder自编码模型(基于Linear)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import hiddenlayer as hl
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report, accuracy_score
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

from AutoEncoder import AutoEncoder

if __name__ == '__main__':
    # 使用手写体数据,准备训练数据集
    # 数据集MNIST被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)
    train_data = MNIST(
        root="./data/MNIST",  # 数据的路径
        train=True,  # 只使用训练数据集
        transform=transforms.ToTensor(),
        download=False
    )
    # 将图像数据转化为向量数据
    train_data_x = train_data.data.type(torch.FloatTensor) / 255.0
    train_data_x = train_data_x.reshape(train_data_x.shape[0], -1)
    train_data_y = train_data.targets
    # 自编码器是无监督,训练网络不需要标签数据y,只需要特征张量x,故此处不需要将x和y整合到一起作为数据集
    # train_data = Data.TensorDataset(train_data_x, train_data_y)
    # 定义一个数据加载器
    train_loader = Data.DataLoader(
        dataset=train_data_x,  # 使用的数据集※
        batch_size=64,  # 批处理样本大小
        shuffle=True,  # 每次迭代前打乱数据
        num_workers=2,  # 使用两个进程
    )
    test_data = MNIST(
        root="./data/MNIST",
        train=False,  # 使用测试集
        transform=transforms.ToTensor(),
        download=False
    )
    test_data_x = test_data.data.type(torch.FloatTensor) / 255.0
    test_data_x = test_data_x.reshape(test_data_x.shape[0], -1)
    test_data_y = test_data.targets

    print("train_data", train_data_x.shape)
    print("test_data", test_data_x.shape)

    # 可视化训练数据集中一个batch的图像内容
    # for step, b_x in enumerate(train_loader):
    #     if step > 0:
    #         break
    #
    # # 网格化,将多幅图像拼在一起;数据为[batch,channel,height,width]形式
    # im = make_grid(b_x.reshape((-1, 1, 28, 28)))
    # im = im.data.numpy().transpose((1, 2, 0))
    # plt.figure()
    # plt.imshow(im)
    # plt.axis("off")
    # plt.show()

    AEmodel = AutoEncoder()
    print(AEmodel)

    # 自编码网络的训练
    # 使用Adam优化器;损失函数选择MSELoss(均方根误差损失),因为AE需重构出原始的手写体数据,所以看作回归问题,即与原始图像的误差越小越好
    # 定义优化器
    optimizer = torch.optim.Adam(AEmodel.parameters(), lr=0.003)  # 学习率
    loss_func = nn.MSELoss()
    # 记录训练过程的指标
    history1 = hl.History()
    # 使用Canvas进行可视化
    canvas1 = hl.Canvas()
    train_num = 0
    val_num = 0
    # 对模型进行迭代训练,对所有的数据训练epoch轮
    for epoch in range(10):
        train_loss_epoch = 0
        # 对训练数据的加载器进行迭代计算
        for step, b_x in enumerate(train_loader):
            # 使用每个batch进行训练模型
            _, output = AEmodel(b_x)
            # 均方根误差;b_x表示每次网络的输入数据,output表示经过自编码网络的输出内容
            loss = loss_func(output, b_x)
            #
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #
            train_loss_epoch += loss.item() * b_x.size(0)
            train_num = train_num + b_x.size(0)

        # 计算一个epoch的损失
        train_loss = train_loss_epoch / train_num
        print("epoch", epoch)
        print("train_loss", train_loss)
        # 保存每个epoch上的输出loss
        history1.log(epoch, train_loss=train_loss)
        # 可视化网络训练的过程
        # with canvas1:
        #     canvas1.draw_plot(history1["train_loss"])

    # 预测测试集前100张图像的输出
    # AEmodel.eval()  # 将模型设置为验证模式
    # _, test_decoder = AEmodel(test_data_x[0:100, :])
    # # 可视化原始的图像
    # plt.figure(figsize=(6, 6))
    # for ii in range(test_decoder.shape[0]):
    #     plt.subplot(10, 10, ii + 1)
    #     im = test_data_x[ii, :]
    #     im = im.data.numpy().reshape(28, 28)
    #     plt.imshow(im, cmap=plt.cm.gray)
    #     plt.axis("off")
    # plt.show()
    # # 可视化编码后的图像
    # plt.figure(figsize=(6, 6))
    # for ii in range(test_decoder.shape[0]):
    #     plt.subplot(10, 10, ii + 1)
    #     im = test_decoder[ii, :]
    #     im = im.data.numpy().reshape(28, 28)
    #     plt.imshow(im, cmap=plt.cm.gray)
    #     plt.axis("off")
    # plt.show()
    # 自编码网络得到的图像有些模糊,而且针对原始图像中的某些细节并不能很好地重构
    # 这是因为在网络中,自编码器部分最后一层只有3个神经元,将784维的数据压缩到三维,会损失大量的信息
    # 此处降到三维主要为了方便数据可视化,在实际情况中,可以使用较多的神经元,保留更丰富的信息。

    # 自编码后的特征训练集和测试集
    train_ae_x, _ = AEmodel(train_data_x)
    train_ae_x = train_ae_x.data.numpy()
    train_y = train_data_y.data.numpy()
    test_ae_x, _ = AEmodel(test_data_x)
    test_ae_x = test_ae_x.data.numpy()
    test_y = test_data_y.data.numpy()

    # PCA降维获得的训练集和测试集前3个主成分
    pcamodel = PCA(n_components=3, random_state=10)
    train_pca_x = pcamodel.fit_transform(train_data_x.data.numpy())
    test_pca_x = pcamodel.fit_transform(test_data_x.data.numpy())

    print(train_pca_x.shape)

    # 分别针对两种类型的数据使用相同的参数,建立 支持向量机 分类器
    # 先使用训练集对SVM分类器进行训练,然后利用测试集测试SVM的分类精度,并使用accuracy_score()和classification_report()函数输出分类器在测试集上的预测效果

    # 使用自编码数据建立分类器,训练和预测
    ae_svc = SVC(kernel="rbf", random_state=123)
    ae_svc.fit(train_ae_x, train_y)
    ae_svc_pre = ae_svc.predict(test_ae_x)
    print(classification_report(test_y, ae_svc_pre))
    print("AE+SVM模型精度", accuracy_score(test_y, ae_svc_pre))

    # 使用PCA降维数据建立分类器,训练和预测
    pca_svc = SVC(kernel="rbf", random_state=123)
    pca_svc.fit(train_pca_x, train_y)
    pca_svc_pre = pca_svc.predict(test_pca_x)
    print(classification_report(test_y, pca_svc_pre))
    print("PCA+SVM模型精度", accuracy_score(test_y, pca_svc_pre))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-01-29 23:05:20  更:2022-01-29 23:07:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 20:32:19-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码