IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于PaddlePaddle的LeNet神经网络来解决识别手写数字任务 -> 正文阅读

[人工智能]基于PaddlePaddle的LeNet神经网络来解决识别手写数字任务

识别手写数字是深度学习中比较基础的项目,简单来说,手写数字识别相当于是程序设计语言里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入手写数字的图片,使得计算机能够识别出这个图片中的数字。听起来实现该功能可能距离我们比较遥远,但是只要你了解到paddlepaddle深度学习框架,相信你就知道实现这个任务是多么轻松加愉快!
在这里插入图片描述

1.环境搭建

在完成这个任务之前,我们需要安装并导入飞浆库。
在PyCharm编译器中,打开终端,在命令行输入如下命令

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle

下载并安装完毕后,我们可以创建python项目,导入并检验paddlepaddle是否安装成功,具体代码如下所示

import paddle    
print(paddle.__version__)

运行效果能够正确输出paddle的版本号即表示检验通过,即安装成功!
如果在环境搭建的过程中遇到TypeError: Descriptors cannot not be created directly的错误,解决方案请参考博客解决方案

2.数据集定义与加载

在百度飞浆中,MNIST是一个手写体数字的图片数据集,MNIST是深度学习领域标准、易用的成熟数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,50%来自人口普查局的工作人员。该数据集的收集目的是希望通过算法,实现对手写数字的识别。
在本任务中,我先后加载了 MNIST 训练集(mode=‘train’)和测试集(mode=‘test’),训练集用于训练模型,测试集用于评估模型效果。该数据集包含60000张训练图片,10000张测试图片,图片的分辨率为28*28,以及对应的分类标签文件。部分图片以及对应的分类标签如下图所示:
在这里插入图片描述
具体加载mnist数据集的代码如下所示:

import paddle
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# 打印数据集里图片数量
print('{} images in train_dataset, {} images in test_dataset'.format(len(train_dataset

Normalize函数的作用是使得没有可比性的数据变得具有可比性,同时又保持相比较的两个数据之间的相对关系。其实,查看源码就会发现,归一化就是将求出某一种数据的均值与方差,将每一个数据减去均值,然后在除去标准差后得到的标准化数据。

paddle.vision.datasets.MNIST(mode=‘train’, transform=transform)是为了加载MINST的60000行训练集数据并对数据进行归一化;paddle.vision.datasets.MNIST(mode=‘test’, transform=transform)是加载MINST的10000行测试集数据并进行归一化。

3.模型组网

飞桨的模型组网有多种方式,既可以直接使用飞桨内置的模型,也可以自定义组网。
『手写数字识别任务』比较简单,普通的神经网络就能达到很高的精度,在本任务中使用了飞桨内置的 LeNet 作为模型。飞桨在 paddle.vision.models 下内置了 CV(Computer Vision) 计算机视图领域的一些经典模型,LeNet 就是其中之一,调用很方便,只需一行代码即可完成 LeNet 的网络构建和初始化。num_classes 字段中定义分类的类别数,因为需要对 0 ~ 9 的十类数字进行分类,所以设置为 10。
另外通过 paddle.summary 可方便地打印网络的基础结构和参数信息。

paddle.summary(net, input_size=None, dtypes=None, input=None)
参数说明:
net (Layer) - 网络实例,必须是 Layer 的子类。
input_size (tuple|InputSpec|list[tuple|InputSpec) - 输入张量的大小。如果网络只有一个输入,那么该值需要设定为tuple或InputSpec。如果模型有多个输入。那么该值需要设定为list[tuple|InputSpec],包含每个输入的shape。默认值:None。
dtypes (str,可选) - 输入张量的数据类型,如果没有给定,默认使用 float32 类型。默认值:None。
input (tensor,可选) - 输入张量数据,如果给出 input ,那么 input_size 和 dtypes 的输入将被忽略。默认值:None。
# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)

# 可视化模型组网结构和参数
paddle.summary(lenet,(1, 1, 28, 28))

4.模型训练与评估

模型训练需完成如下步骤:
使用 paddle.Model 封装模型。 将网络结构组合成可快速使用 飞桨高层 API 进行训练、评估、推理的实例,方便后续操作。
使用 paddle.Model.prepare 完成训练的配置准备工作。 包括损失函数、优化器和评价指标等。飞桨在 paddle.optimizer 下提供了优化器算法相关 API,在 paddle.nn Loss层 提供了损失函数相关 API,在 paddle.metric 下提供了评价指标相关 API。
使用 paddle.Model.fit 配置循环参数并启动训练。 配置参数包括指定训练的数据源 train_dataset、训练的批大小 batch_size、训练轮数 epochs 等,执行后将自动完成模型的训练循环。
因为是分类任务,这里损失函数使用常见的 CrossEntropyLoss (交叉熵损失函数),优化器使用 Adam,评价指标使用 Accuracy 来计算模型在训练集上的精度。
模型训练完成之后,调用 paddle.Model.evaluate使用预先定义的测试数据集,来评估训练好的模型效果,评估完成后将输出模型在测试集上的损失函数值 loss 和精度 acc。

# 封装模型,便于进行后续的训练、评估和推理
model = paddle.Model(lenet)

# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()), 
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 开始训练
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)

# 进行模型评估
model.evaluate(test_dataset, batch_size=64, verbose=1)

5.模型推理

执行模型推理时,可通过 paddle.Model.predict_batch 执行推理操作。
如下示例中,选择测试集中的一张图片 test_dataset[0] 作为输入,执行推理并打印结果,检验推理的结果与可视化图片是否一致。

# 从测试集中取出一张图片
img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
img_batch = np.expand_dims(img.astype('float32'), axis=0)

# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))
# 可视化图片
from matplotlib import pyplot as plt
plt.imshow(img[0])
plt.show()

6.全部代码展示

import paddle
import numpy as np
from paddle.vision.transforms import Normalize
# 对数据进行归一化
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)

# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()), 
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())
# 模型训练
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型评估
model.evaluate(test_dataset, batch_size=64, verbose=1)
# 从测试集中取出一张图片
img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
img_batch = np.expand_dims(img.astype('float32'), axis=0)

# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))
# 可视化图片
from matplotlib import pyplot as plt
plt.imshow(img[0])
plt.show()

运行结果展示:
在这里插入图片描述
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-04 22:54:12  更:2022-07-04 22:54:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:59:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码