IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习 pytorch手写数字识别 MNIST数据集 解析+详细注释 -> 正文阅读

[人工智能]深度学习 pytorch手写数字识别 MNIST数据集 解析+详细注释


文件结构
①存放训练之后导出的模型;
②存放数据集;

在这里插入图片描述

1 模型构建

神经网络由对数据进行操作的层/模块(layers/modules)组成。torch.nn提供构建网络的所有blocks,
在PyTorch中的每个modules都继承了nn.Module,可以构建各种复杂的网络结构。
通过nn.Module定义神经网络,使用init初始化,对数据的所有操作都在forward()中实现

import torch 
import torch.nn as nn

# 1 创建网络模型 model.py

# 卷积神经网络(两个卷积层)

class ConvNet(nn.Module): # nn,neural network
    def __init__(self, num_classes=10): #0~9种类别
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), #输出尺寸:(n+2p-f)/s +1 = (28+4-5)/1 + 1 = 28,输入28输出还是28 ,1*28*28
            nn.BatchNorm2d(16), # 输出通道16,16*28*28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)) #最大池化,做下采样,f=2,s=2相当于图像减半 #图片变模糊,保留原图片的特征,让训练参数减少。 #16*14*14
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), #输入16通道,输出32通道
            nn.BatchNorm2d(32), #32*14*14
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)) #32*7*7
        self.fc = nn.Linear(7*7*32, num_classes) #全连接层展开
        
    def forward(self, x): #前向传播
        out = self.layer1(x) #in bx1x28x28 out bx16x14x14
        out = self.layer2(out)#out bx32x7x7
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)#bx10
        return out

2 训练 train.py

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from model import ConvNet #加载网络模型

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #如果有cuda 就用GPU,没有就用CPU

# Hyper parameters 超参数
num_epochs = 5 #训练5轮
num_classes = 10 #类别10,数字0~9
batch_size = 100 #一次送入100个数据
learning_rate = 0.001 #梯度下降步长

# MNIST dataset  #数据集与加载torchvision中已经集成了,直接调用
train_dataset = torchvision.datasets.MNIST(root='./data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)#训练时,数据打乱

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False) #测试时,不打乱
# 模型初始化
model = ConvNet(num_classes).to(device)

# Loss and optimizer 损失和优化器
criterion = nn.CrossEntropyLoss() #交叉熵
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device) #images模型输入
        labels = labels.to(device) #labels用于计算loss
        
        # 前向传播,即网络如何根据输入得到输出的
        outputs = model(images)
        # loss计算
        loss = criterion(outputs, labels)
        
        # 反向传播与优化,反向传播算法的核心是代价函数对网络中参数(各层的权重和偏置)的偏导表达式和。
        optimizer.zero_grad() #梯度清零:重置模型参数的梯度。默认是累加,为了防止重复计数,在每次迭代时显式地将它们归零。
        loss.backward()#反向传播计算梯度:计算当前张量w.r.t图叶的梯度。
        optimizer.step()#参数更新:根据上面计算的梯度,调整参数
        
        if (i+1) % 100 == 0: #每个batch打印以此结果
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# 模型测试
model.eval()  
with torch.no_grad(): #禁用梯度计算:当我们训练了模型,只是想跑一下前向测试我们的数据
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1) #找出类别里最大的预测结果
        total += labels.size(0) #统计测试总数
        correct += (predicted == labels).sum().item() #统计预测正确总数

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# 模型导出
#PyTorch模型将学习到的参数存储在一个内部状态字典中,称为state_dict。使用torch.save()保存
torch.save(model.state_dict(), './checkpoints/model.ckpt')

在中端输入python train.py 或者运行按钮

在这里插入图片描述

3 测试 eval.py

文件默认2828的图片,
测试中用了两张不一样大小的图片,如果不是28
28,会resize为28*28,再识别。
在这里插入图片描述

import torch 
import torch.nn as nn

# 创建网络模型

import torch 
import torch.nn as nn
import numpy as np
import cv2
from model import ConvNet

#模型加载
model=ConvNet(10) #类别,数字0~9,10类
state_dict=torch.load('./checkpoints/model.ckpt')
model.load_state_dict(state_dict)
model.eval() #有BN和Dropout,测试时要加model.eval()with torch.no_grad():
    #数据加载
    image=cv2.imread('2.png',0) #读入灰度图,WH两个维度 28*28
    print(image.shape)
    if(image.shape != (28,28)): #如果不是28*28,resize为28*28
        image = cv2.resize(image,(28,28))
    image = np.expand_dims(image, 0) # 增加1个维度
    image = np.expand_dims(image, 0) # 再增加一个维度
    image=1.0-image.astype(np.float32)/255.0 #归一化到01,,因为测试图片是白底黑字,但训练集是黑底白字,做一个反色1.0-image
    print(image.shape) #1*1*28*28,batch,通道数,H,W
    image_t=torch.from_numpy(image) #转成Torch的张量
    outputs = model(image_t)
    _, predicted = torch.max(outputs.data, 1) # 输出0~9模型中分值最高的
    print(predicted.numpy())
    print("test end")

在中端输入python eval.py 或者运行按钮;
分别测试数字4,2
在这里插入图片描述
在这里插入图片描述

4 工程文件、数据集、源码下载

手写数字识别 MNIST数据集 解析+详细注释

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-11-05 00:28:48  更:2022-11-05 00:30:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/19 20:53:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码