IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习笔记_搭建一个简单网络(完整版)_手写数字识别MNIST -> 正文阅读

[人工智能]深度学习笔记_搭建一个简单网络(完整版)_手写数字识别MNIST

目录

1. 加载MNIST的train数据和test数据

2. 定义神经网络

3. 使用定义的网络进行训练

4. 使用测试集,计算预测精度

5. 辅助工具函数


1. 加载MNIST的train数据和test数据

import torch
import torchvision  # 处理图像视频, 包含一些常用的数据集、模型、转换函数等等
from torch import nn, optim
from torch.nn import functional as F

from matplotlib import pyplot as plt
from utils import plot_curve, plot_image, one_hot

batch_size = 512
# step1. 加载数据集
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data_john', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data_john/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)


# # test: 显示下训练数据集的前6张图像及对应的标签
# x,y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x,y,"image_gt")

显示下训练数据集的前6张图像及对应的标签:

2. 定义神经网络

这里采用 3个线性层,做简单示范。

# step2. 定义神经网络
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()

        # 定义三个线性层 y = wx+b
        # 输入X的size是:[batch_size, 28*28=784]
        # y = w1 * x + b1, 例如:参数数量:w1.size是[256,784] (一张图像x的size是[784,1]), b.size是[256]
        self.fc1 = nn.Linear(28 * 28, 256)  # 28*28是输入图像的大小,256是自定义的中间层大小
        # y = w2 * x + b2
        self.fc2 = nn.Linear(256, 64)       # 中间层数的结果,本层输入层数取决于上一层的输出层数,本层输出决定了下一层的输入层数。
        # y = w3 * x + b3
        self.fc3 = nn.Linear(64, 10)        # 10是要求的输出分类层数

    def forward(self, x):
        x = F.relu(self.fc1(x))             # 使用激活函数,增加非线性
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                     # 最后一层根据网络结果输出
        return x

3. 使用定义的网络进行训练

# step3. 开始训练
net = MyNet()
# 定义梯度下降方式
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader): # train_loader中有 n 个 (x,y), 每一个 x和 y包含 batch_size张图像,所以总的图像数量是= batch_idx * batch_size
        x = x.view(x.size(0), 28 * 28)  # view 等价于reshape, [512,1,28*28] => [512,28*28]
        out = net(x)                    # [batch_size,10]
        y_onehot = one_hot(y)           # 如将 [512,] 变成 [512,10], 原来一维数组中对应m行的值n,对应新的二维数组m行的第n列设置为1
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())
# 打印训练过程中的loss结果
plot_curve(train_loss)

loss的下降结果:

4. 使用测试集,计算预测精度

# step4. 进行测试,计算预测精度
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28 * 28)
    out = net(x)
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("acc: ", acc)

# 可视化部分预测结果
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28)) # 二维 [batch_size, 10]
pred = out.argmax(dim=1)              # 一维 [batch_size,]
plot_image(x, pred, 'image_predict')

预测精度:

可视化部分预测结果:

5. 辅助工具函数

定义到uitls.py文件中:

用于显示图像,打印一维数组,one hot操作

import torch
from matplotlib import pyplot as plt

# 绘制一维数据图
def plot_curve(data):
    fig = plt.figure()  # 定义一张图纸
    plt.plot(range(len(data)), data, color="blue")  # 绘制一维数组
    plt.legend(["value"], loc="upper right")        # 添加图例,即数据说明标签
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()


def plot_image(img, label, name):
    '''
    :param img:     比如:torch.Size([batch_size=512, 1, 28, 28])
    :param label:   比如:torch.Size([512])
    :param name:    string
    '''
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1) # 2*3个小图像
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap="gray", interpolation="none")  # 图像进行正则化,然后显示出来
        plt.title("{}: {}".format(name, label[i].item()))  # 显示每一个plot的标题
        plt.xticks([])  # 设置x轴的刻度标签为空,即不显示刻度
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth) # 定义一个 [batch_size, 10]大小的矩阵
    idx = torch.LongTensor(label).view(-1, 1)  # 把 label reshape成 [batch_size, 1]尺寸的2维tensor
    out.scatter_(dim=1, index=idx, value=1)  # 改变 out的第dim=1维度的数据,out中值被改变值的索引,是index中对应的值, 填充的值是 1,
                                             # 如第2个样本是“6”,则第1行第5列(矩阵索引从0开始)填充为1.
    return out

# if __name__ == '__main__':
#     data = [1,2,3,4,5,4,3,2,5,6,8]
#     plot_curve(data)

参考:

深度学习入门_哔哩哔哩_bilibili

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-11 12:42:26  更:2021-11-11 12:43:20 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 6:58:53-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码