IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 用PyTorch进行手写数字识别 -> 正文阅读

[人工智能]用PyTorch进行手写数字识别

数据准备

torch.utils.data.Datasets是PyTorch用来表示数据集的类,它是用PyTorch进行手写数字识别的关键。
下面是加载mnist数据集并对其可视化的代码

from torchvision import datasets
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt
import numpy as np

mnist = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',
                       train=True,download=True)

for i, j in enumerate(np.random.randint(0,len(mnist),(10,))):
    data, label = mnist[j]
    plt.subplot(2,5,i+1)
    plt.imshow(data)

这段代码中,首先实例化了Datasets对象mnist ,datasets.MNIST能够自动下载数据集保存到本地磁盘的root位置(自己设置),参数train默认为True,用于控制加载的数据集是训练集还是测试集。for循环中,使用len(mnist)调用了__len__方法,使用mnist[j]调用了__getitem__方法(在我们自己建立数据集时,需要继承Dataset,并且覆写__len__和__getitem__两个方法)。最后两行代码绘制了MNIST手写数字数据集。
运行代码,查看数据集的部分数据可视化结果
在这里插入图片描述
在变量浏览器中可以看到有关变量的内容,其中mnist就是实例化的数据集对象,它包含6000张图像内容,在for循环过程中,data读取的是28×28的Image类型图像,label是该图像对应的标签,也就是图像上表示的数字。
在这里插入图片描述
由于数据预处理是非常重要的步骤,所以PyTorch提供了torchvision.transforms用于处理数据及数据增强。在这里我们使用了torchvision.transforms.ToTensor将PIL Image或者numpy.ndarray类型的数据转换为Tensor,并且它会将数据从【0,255】映射到【0,1】。torchvision.transforms.Normalize会将数据标准化,加速模型在训练中的收敛速率。在使用中,可利用torchvision.transforms.Compose将多个transforms组合在一起,被包含的transforms会顺序执行。
数据流程处理准备完善后开始读取用于训练的数据,torch.utils.data.DataLoader提供了迭代数据、随机抽取数据、批量化数据。
下面的代码中实例化了mnist对象,定义了transforms方法trans并将其使用到数据集的实例化过程中,用于对数据集进行处理。然后定义了函数imshow,函数体的的第一行代码将数据从标准化的数据中恢复,第二行代码将Tensor类型转换为ndarray,这样才可以用matplotlib绘制出来,绘制的结果如下图所示。函数体最后一行使用transpose函数将矩阵维度从(C,W,H)转换为(W,H,C),这样才符合正常的通道顺序。

import torchvision
from torchvision import datasets,transforms
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt
import numpy as np
import torch

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

mnist = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',
                       train=True,download=True,transform=trans)

def imshow(img):
    img = img * 0.3081 + 0.1307
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))

dataloader = torch.utils.data.DataLoader(mnist, batch_size=4, shuffle=True, num_workers=0)
images, labels = next(iter(dataloader))

imshow(torchvision.utils.make_grid(images))

运行代码得到的变量情况及绘制结果如下图
在这里插入图片描述

网络模型

下面构建用于识别手写数字的神经网络模型

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        self.inputlayer = nn.Sequential(nn.Linear(28*28, 256),
                                        nn.ReLU(),
                                        nn.Dropout(0.2))
        self.hiddenlayer = nn.Sequential(nn.Linear(256, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.2))
        self.outlayer = nn.Sequential(nn.Linear(256, 10))

    def forward(self, x):
        #将输入图像拉伸为一维向量
        x = x.view(x.size(0), -1)

        x = self.inputlayer(x)
        x = self.hiddenlayer(x)
        x = self.outlayer(x)
        return x

通过nn.Module对象看到其网络结构,如下

In [12]: print(MLP())
MLP(
  (inputlayer): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (hiddenlayer): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (outlayer): Sequential(
    (0): Linear(in_features=256, out_features=10, bias=True)
  )
)

完整实现

准备好数据和模型后,就可以进行训练模型了。下面分别定义了数据处理和加载流程、模型、优化器、损失函数以及用准确率评估模型能力。训练过程持续10个epoch

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim 
from tqdm import tqdm
from torchvision import datasets,transforms
import matplotlib.pylab as plt

#数据处理和加载
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
mnist_train = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',
                       train=True,download=True,transform=trans)
mnist_val = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',
                       train=False,download=True,transform=trans)

trainloader = DataLoader(mnist_train, batch_size=16, shuffle=True, num_workers=0)
valloader = DataLoader(mnist_val, batch_size=16, shuffle=True, num_workers=0)

#定义网络模型
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        self.inputlayer = nn.Sequential(nn.Linear(28*28, 256),
                                        nn.ReLU(),
                                        nn.Dropout(0.2))
        self.hiddenlayer = nn.Sequential(nn.Linear(256, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.2))
        self.outlayer = nn.Sequential(nn.Linear(256, 10))

    def forward(self, x):
        #将输入图像拉伸为一维向量
        x = x.view(x.size(0), -1)

        x = self.inputlayer(x)
        x = self.hiddenlayer(x)
        x = self.outlayer(x)
        return x

model = MLP()

#优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

#损失函数
celoss = nn.CrossEntropyLoss()

#计算准确率
def accuracy(pred, target):
    pred_label = torch.argmax(pred, 1)
    correct = sum(pred_label == target).to(torch.float)

    return correct, len(pred)

acc = {'train':[], 'val':[]}
loss_all = {'train':[], 'val':[]}

for epoch in tqdm(range(10)):
    #设置为验证模式
    model.eval()
    numer_val, denumer_val, loss_tr = 0., 0., 0.
    with torch.no_grad():
        for data, target in valloader:
            output = model(data)
            loss = celoss(output, target)
            loss_tr += loss.data

            num, denum = accuracy(output, target)
            numer_val += num
            denumer_val += denum
    
    #设置为训练模式
    model.train()
    numer_tr, denumer_tr, loss_val = 0., 0., 0.
    for data, target in trainloader:
        optimizer.zero_grad()
        output = model(data)
        loss = celoss(output, target)
        loss_val += loss.data
        loss.backward()
        optimizer.step()
        num, denum = accuracy(output, target)
        numer_tr += num
        denumer_tr += denum

    loss_all['train'].append(loss_tr/len(trainloader))
    loss_all['val'].append(loss_val/len(valloader))
    acc['train'].append(numer_tr/denumer_tr)
    acc['val'].append(numer_val/denumer_val)

运行完成,如下
在这里插入图片描述
设计到的变量情况如下
在这里插入图片描述
在这里插入图片描述
查看模型训练迭代过程的损失图像

plt.plot(loss_all['train'])
plt.plot(loss_all['val'])

在这里插入图片描述
查看训练迭代过程的准确率图

plt.plot(acc['train'])
plt.plot(acc['val'])

在这里插入图片描述
O V E R !

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-08 22:28:33  更:2022-03-08 22:31:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 2:23:23-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码