手写数字识别是人工智能入门级的应用案例,同时也有很强的实用价值,例如在邮政系统中存在大量信件的邮编数字的识别。本文参考《深度学习工程师认证初级教程》中5.3.1节手写数字识别案例,采用LeNet实现,书中以Paddle1为主,思路可以参考,这里用Paddle2实现。
一、数据集
使用经典的MNIST数据集,数据集中已经分配好训练集6000张,测试集1000张,而且将图片大小做了规则化(28*28)和居中化。MINIST数据集的官网上介绍了各类传统和神经网络在这个数据上做分类识别的错误率,可见神经网络相对传统机器学习算法要优越很多。
可以看下MNIST数据集中的一条样本数据为:
Tensor(shape=[28, 28], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07058824, 0.45098042, 0.77647066, 0.99607849, 0.58431375, 0.04705883, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.10980393, 0.77647066, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.50980395, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07450981, 0.83921576, 0.99607849, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.69803923, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.17647059, 0.89019614, 0.99215692, 0.99607849, 0.83529419, 0.54117650, 0.99215692, 0.99215692, 0.91764712, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.06274510, 0.93725497, 0.99215692, 0.99215692, 0.57647061, 0.07450981, 0.40784317, 0.99215692, 0.99215692, 0.91764712, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00784314, 0.51764709, 0.99215692, 0.94117653, 0.27843139, 0., 0.09411766, 0.80392164, 0.99215692, 0.99215692, 0.75294125, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.07450981, 0.99215692, 0.99215692, 0.49803925, 0., 0., 0.63921571, 0.99215692, 0.99215692, 0.92941183, 0.23137257, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.43137258, 0.99215692, 0.99215692, 0.14509805, 0.16470589, 0.54901963, 0.98823535, 0.99215692, 0.99215692, 0.50588238, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.40392160, 0.99215692, 0.99215692, 0.86666673, 0.95294124, 0.99607849, 0.99215692, 0.99215692, 0.69803923, 0.01176471, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.01960784, 0.51372552, 0.97647065, 0.99215692, 0.99215692, 0.99607849, 0.99215692, 0.82745105, 0.08627451, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.03137255, 0.80392164, 1., 0.99607849, 0.49411768, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.36078432, 0.99215692, 0.99607849, 0.85882360, 0.10588236, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.01176471, 0.85098046, 0.99215692, 0.87450987, 0.14509805, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.57647061, 0.99215692, 0.99215692, 0.26274511, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.06274510, 0.96470594, 0.99215692, 0.71764708, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.03529412, 0.76078439, 0.99215692, 0.98431379, 0.23921570, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.45882356, 0.99215692, 0.99215692, 0.49803925, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.10980393, 0.91764712, 0.99215692, 0.77647066, 0.05882353, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.03529412, 0.68235296, 0.99215692, 0.95294124, 0.08627451, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.07843138, 0.66274512, 0.96470594, 0.27843139, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
可以看出是一个28*28大小的张量(Tensor),其中大量的0,也就是黑色部分,有值的地方数据在0-1之间,是归一化后的图片,如果要画出来要先乘255。
二、配置说明
1、输入输出的配置
Paddle中自带了MNIST数据集,调用以下接口即可下载数据:
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=ToTensor())
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=ToTensor())
可以调用一下接口按批加载数据:
BATCH_SIZE = 32
train_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
print(x_data.shape)
print(y_data.shape)
break
得到的train_loader 和test_loader 是迭代器,在for 循环中调用一次即可以得到一批(32条)数据。
本文使用LeNet进行图片分类,LeNet的输入是28*28的图片,输出是通过softmax得到的10分类概率,其中概率最大的为预测的分类值。可通过简单的找最大值处理找到分类结果:
out = paddle.argmax(out, axis=1).numpy()
plt.figure()
plt.title("predict:%d" %(out))
plt.imshow(img)
2、网络的配置
使用LeNet可以构建一个类表示网络,也可以直接调用Paddle自带的LeNet。
自己构建的方法为:
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.nn.functional as F
class LeNet5(nn.Layer):
def __init__(self, num_classes=10):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2D(1,6,5, stride=1, padding=0)
self.maxpool1= nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = nn.Conv2D(6,16,5, stride=1, padding=0)
self.maxpool2 = nn.MaxPool2D(2,2)
self.flat = nn.Flatten()
self.linear1 = nn.Linear(16*5*5, 120)
self.linear2 = nn.Linear(120, 84)
self.linear3 = nn.Linear(84, num_classes)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.maxpool2(x)
x = self.flat(x)
x = self.linear1(x)
lenet = LeNet5()
paddle.summary(lenet, (1, 1, 28, 28))
使用Paddle自带的LeNet网络可以用:
model = paddle.vision.models.LeNet()
paddle.summary(model, (1,1,28,28))
三、训练模型
配置好了数据和网络,下一步就该训练模型了,以得到能较好分类手写数字的模型。
import os
def train(model, epochs, train_loader, eval_loader, optim, metric_func, loss_func):
train_losses = []
eval_losses = []
eval_acces = []
for epoch in range(epochs):
""" train"""
model.train()
train_loss = 0
cnt = 0
for input, label in train_loader:
out = model(input)
loss = loss_func(out, label)
train_loss += loss
loss.backward()
optim.step()
optim.clear_grad()
cnt += 1
train_loss /= float(cnt)
train_losses.append(train_loss)
""" evaluation"""
model.eval()
eval_loss = 0
cnt = 0
acc = 0
with paddle.no_grad():
metric_func.reset()
for eval_x, eval_y in eval_loader:
outs = model(eval_x)
loss = loss_func(outs, eval_y)
eval_loss += loss
correct = metric_func.compute(outs, eval_y)
metric_func.update(correct)
acc = metric_func.accumulate()
cnt += 1
eval_loss /= float(cnt)
eval_losses.append(eval_loss)
eval_acces.append(acc)
metric_func.reset()
print('---------epoch: %d, train_loss: %.3f, eval_loss: %.3f, eval_acc: %.3f-------' \
%(epoch, train_loss, eval_loss, acc))
if acc >= max(eval_acces):
os.system("rm -f model_*")
model_name = str("model_%d.pdparams" % epoch)
paddle.save(model.state_dict(), "model.pdparams")
return model, train_losses, eval_losses, eval_acces
""" 训练相关超参数 """
epochs = 5
lr = 0.001
""" 优化方法和损失函数"""
optim = paddle.optimizer.Momentum(learning_rate=lr, parameters=model.parameters(), momentum=0.9)
loss_func = nn.loss.CrossEntropyLoss()
metric_func = paddle.metric.Accuracy()
""" 开始训练"""
model, train_losses, eval_losses, eval_acces = train(model, epochs, train_loader, test_loader, \
optim, metric_func, loss_func)
四、应用模型
模型训练好后,使用模型开展推理的应用。
""" 加载模型权重"""
infer_model = paddle.vision.models.LeNet()
state_dict_load = paddle.load('model.pdparams')
infer_model.set_state_dict(state_dict_load)
infer_model.eval()
for data in test_loader():
data_array = paddle.Tensor.numpy(data[0][0][0])*255
img = Image.fromarray(data_array.astype(np.uint8))
out = infer_model(data[0][0].reshape((1, 1, 28, 28)))
out = paddle.argmax(out, axis=1).numpy()
plt.figure()
plt.title("predict:%d" %(out))
plt.imshow(img)
break
五、发布模型
预览本项目运行的结果可以在百度AI Studio的LeNet手写数字识别项目NoteBook查看。在百度AI studio中可以通过“部署模型”将模型以API接口或者“体验馆”的方式被调用。
|