IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> MNIST手写数字识别 -> 正文阅读

[人工智能]MNIST手写数字识别

进入到研究生阶段了,从头学一下Pytorch,在这个小破站上记录一下自己的学习过程。
本文使用的是Pytorch来做手写数字的识别。

step0:先引入一些相关的包和库

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt


from utils import plot_image,plot_curve,one_hot

这里的utils是定义的一些辅助工具,包括loss下降的绘图函数和one_hot编码及图片显示的辅助函数。代码如下:
utils.py

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/10/26 下午4:47

import torch
from matplotlib import pyplot as plt

###loss下降
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()



def plot_image(img,label,name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
        plt.title("{}:{}".format(name,label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(labels,depth=10):
    out = torch.zeros(labels.size(0),depth)
    idx = torch.LongTensor(labels).view(-1,1)
    out.scatter_(dim = 1, index = idx,value=1)
    return out

step1:加载数据
使用torch的DataLoader方法加载数据,MNIST数据集中的图片大小为28*28,比较小,batch_size可以设置大一点。

batch_size = 512
###step1  load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data',train=True,download=True,
                               transform = torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   #数据归一化
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081,))
                               ])),
    batch_size = batch_size,shuffle = True
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/',train=False,download=True,
                               transform = torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081,))
                               ])),
    batch_size = batch_size , shuffle = False
)

transforms.Compose方法将数据转为Tensor和做数据归一化,训练集中设置shuffle=True是将训练数据打乱.

step2:定义网络结构
使用简单的三层线性模型来做简单的识别。

class Net(nn.Module):

    def __init__(self):
        super(Net,self).__init__()

        self.fc1 = nn.Linear(28*28,256)
        self.fc2 = nn.Linear(256,64)
        self.fc3 = nn.Linear(64,10)

    def forward(self, x):
        #x:[batch_size,1,28,28]
        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)

        return x

step3:train
训练3个epoch

train_loss = []
net =Net()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

for epoch in range(3):
    for batch_idx,(x,y) in enumerate(train_loader):
        # x:[batch_size,1,28,28]
        #将x打平成二维的
        # y:batch_size
        x = x.view(x.size(0),28*28)
        out = net(x)
        y_onehot = one_hot(y)

        ##lose = mse(y,out)

        loss = F.mse_loss(out,y_onehot)

        optimizer.zero_grad() #梯度清零
        loss.backward() #计算梯度
        optimizer.step() #更新参数

        ##打印loss
        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch,batch_idx,loss.item())
plot_curve(train_loss)

20211027 205211屏幕截图.png

step4:test
最后在验证集测试训练的准确率

total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)

acc = total_correct / total_num
print("test acc:",acc)

20211027 205405屏幕截图.png

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-28 12:23:44  更:2021-10-28 12:24:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 9:00:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码