IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> [技术分享]LeNet卷积神经网络实现手写数字识别 -> 正文阅读

[人工智能][技术分享]LeNet卷积神经网络实现手写数字识别

手写数字识别是人工智能入门级的应用案例,同时也有很强的实用价值,例如在邮政系统中存在大量信件的邮编数字的识别。本文参考《深度学习工程师认证初级教程》中5.3.1节手写数字识别案例,采用LeNet实现,书中以Paddle1为主,思路可以参考,这里用Paddle2实现。

一、数据集

使用经典的MNIST数据集,数据集中已经分配好训练集6000张,测试集1000张,而且将图片大小做了规则化(28*28)和居中化。MINIST数据集的官网上介绍了各类传统和神经网络在这个数据上做分类识别的错误率,可见神经网络相对传统机器学习算法要优越很多。

可以看下MNIST数据集中的一条样本数据为:

Tensor(shape=[28, 28], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07058824, 0.45098042, 0.77647066, 0.99607849, 0.58431375, 0.04705883, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.10980393, 0.77647066, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.50980395, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07450981, 0.83921576, 0.99607849, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.69803923, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.17647059, 0.89019614, 0.99215692, 0.99607849, 0.83529419, 0.54117650, 0.99215692, 0.99215692, 0.91764712, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.06274510, 0.93725497, 0.99215692, 0.99215692, 0.57647061, 0.07450981, 0.40784317, 0.99215692, 0.99215692, 0.91764712, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00784314, 0.51764709, 0.99215692, 0.94117653, 0.27843139, 0., 0.09411766, 0.80392164, 0.99215692, 0.99215692, 0.75294125, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07450981, 0.99215692, 0.99215692, 0.49803925, 0., 0., 0.63921571, 0.99215692, 0.99215692, 0.92941183, 0.23137257, 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.43137258, 0.99215692, 0.99215692, 0.14509805, 0.16470589, 0.54901963, 0.98823535, 0.99215692, 0.99215692, 0.50588238, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.40392160, 0.99215692, 0.99215692, 0.86666673, 0.95294124, 0.99607849, 0.99215692, 0.99215692, 0.69803923, 0.01176471, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.01960784, 0.51372552, 0.97647065, 0.99215692, 0.99215692, 0.99607849, 0.99215692, 0.82745105, 0.08627451, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.03137255, 0.80392164, 1., 0.99607849, 0.49411768, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.36078432, 0.99215692, 0.99607849, 0.85882360, 0.10588236, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.01176471, 0.85098046, 0.99215692, 0.87450987, 0.14509805, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.57647061, 0.99215692, 0.99215692, 0.26274511, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.06274510, 0.96470594, 0.99215692, 0.71764708, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.03529412, 0.76078439, 0.99215692, 0.98431379, 0.23921570, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.45882356, 0.99215692, 0.99215692, 0.49803925, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.10980393, 0.91764712, 0.99215692, 0.77647066, 0.05882353, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.03529412, 0.68235296, 0.99215692, 0.95294124, 0.08627451, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.07843138, 0.66274512, 0.96470594, 0.27843139, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

可以看出是一个28*28大小的张量(Tensor),其中大量的0,也就是黑色部分,有值的地方数据在0-1之间,是归一化后的图片,如果要画出来要先乘255。

在这里插入图片描述

二、配置说明

1、输入输出的配置

Paddle中自带了MNIST数据集,调用以下接口即可下载数据:

# 下载MINIST数据
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=ToTensor())
test_dataset =  paddle.vision.datasets.MNIST(mode='test', transform=ToTensor())

可以调用一下接口按批加载数据:

# 加载数据
BATCH_SIZE = 32

train_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
for batch_id, data in enumerate(train_loader()):
    x_data = data[0]
    y_data = data[1]
    print(x_data.shape)
    print(y_data.shape)
    break
# output:
#[32, 1, 28, 28]
#[32, 1]

得到的train_loadertest_loader是迭代器,在for循环中调用一次即可以得到一批(32条)数据。

本文使用LeNet进行图片分类,LeNet的输入是28*28的图片,输出是通过softmax得到的10分类概率,其中概率最大的为预测的分类值。可通过简单的找最大值处理找到分类结果:

    out = paddle.argmax(out, axis=1).numpy()
    plt.figure()
    plt.title("predict:%d" %(out))
    plt.imshow(img)

在这里插入图片描述

2、网络的配置

使用LeNet可以构建一个类表示网络,也可以直接调用Paddle自带的LeNet。

自己构建的方法为:

import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.nn.functional as F

class LeNet5(nn.Layer):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2D(1,6,5, stride=1, padding=0)
        self.maxpool1= nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2D(6,16,5, stride=1, padding=0)
        self.maxpool2 = nn.MaxPool2D(2,2)
        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(16*5*5, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, num_classes)
        
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.maxpool2(x)
        x = self.flat(x)
        x = self.linear1(x)
        

        
lenet = LeNet5()
paddle.summary(lenet, (1, 1, 28, 28))

使用Paddle自带的LeNet网络可以用:

model = paddle.vision.models.LeNet() #加载paddle自带的常用模型,并预加载训练参数
paddle.summary(model, (1,1,28,28))

三、训练模型

配置好了数据和网络,下一步就该训练模型了,以得到能较好分类手写数字的模型。

import os

def train(model, epochs, train_loader, eval_loader, optim, metric_func, loss_func):
    train_losses = []
    eval_losses = []
    eval_acces = []

    for epoch in range(epochs):
        """ train"""    
        model.train() 
        train_loss = 0
        cnt = 0
        for input, label in train_loader:  # [n // B]
            out = model(input)
            loss = loss_func(out, label)
            train_loss += loss

            loss.backward()
            optim.step()
            optim.clear_grad()
            cnt += 1
        train_loss /= float(cnt)   # 单个epoch的平均loss,用于可视化
        train_losses.append(train_loss)

        """ evaluation"""
        model.eval()
        eval_loss = 0
        cnt = 0
        acc = 0
        with paddle.no_grad():
            metric_func.reset()
            for eval_x, eval_y in eval_loader: # n // B + 1 if n % B else 0
                outs = model(eval_x)
                loss = loss_func(outs, eval_y)
                eval_loss += loss

                correct = metric_func.compute(outs, eval_y)
                metric_func.update(correct)
                acc = metric_func.accumulate()
                cnt += 1
            eval_loss /= float(cnt)
            eval_losses.append(eval_loss)
            eval_acces.append(acc)
            metric_func.reset()
        
        print('---------epoch: %d, train_loss: %.3f, eval_loss: %.3f, eval_acc: %.3f-------' \
                %(epoch, train_loss, eval_loss, acc))
        # save
        if acc >= max(eval_acces):
            os.system("rm -f model_*")
            model_name = str("model_%d.pdparams" % epoch)
            paddle.save(model.state_dict(), "model.pdparams")

    return model, train_losses, eval_losses, eval_acces
# 训练模型

""" 训练相关超参数 """
epochs = 5
lr = 0.001

""" 优化方法和损失函数"""
optim = paddle.optimizer.Momentum(learning_rate=lr, parameters=model.parameters(), momentum=0.9)
loss_func = nn.loss.CrossEntropyLoss()
metric_func = paddle.metric.Accuracy()

""" 开始训练"""
model, train_losses, eval_losses, eval_acces = train(model, epochs, train_loader, test_loader, \
                optim, metric_func, loss_func)

四、应用模型

模型训练好后,使用模型开展推理的应用。

# 模型推理

""" 加载模型权重"""
infer_model = paddle.vision.models.LeNet()
state_dict_load = paddle.load('model.pdparams')
infer_model.set_state_dict(state_dict_load)
infer_model.eval()

for data in test_loader():
    data_array = paddle.Tensor.numpy(data[0][0][0])*255
    img = Image.fromarray(data_array.astype(np.uint8))
    out = infer_model(data[0][0].reshape((1, 1, 28, 28)))
    out = paddle.argmax(out, axis=1).numpy()
    plt.figure()
    plt.title("predict:%d" %(out))
    plt.imshow(img)
    break

五、发布模型

预览本项目运行的结果可以在百度AI Studio的LeNet手写数字识别项目NoteBook查看。在百度AI studio中可以通过“部署模型”将模型以API接口或者“体验馆”的方式被调用。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-03 17:04:52  更:2021-10-03 17:06:12 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 14:10:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码