IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 手写数字识别分类器 -> 正文阅读

[人工智能]手写数字识别分类器

目录

使用MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。
该数据集分成训练集(前6万张图片)和测试集(最后1万张图片)

1.训练二元分类器

先简化问题,只尝试识别一个数字,比如数字5

from sklearn.datasets import fetch_openml  # 我的sklearn版本为1.0.2
from sklearn.linear_model import SGDClassifier
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 下载mnist数据集,fetch_openml默认返回的是一个DataFrame,设置as_frame=False返回一个Bunch
# mnist.keys() 可查看所有的键
# data键,包含一个数组,每个实例为一行,每个特征为一列。
# target键,包含一个带有标记的数组
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# 共有7万张图片,因为图片是28×28像素,所以每张图片有784个特征,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)
x, y = mnist["data"], mnist["target"] # x.shape=(70000, 784),y.shape=(70000,)
y = y.astype(np.uint8) # 注意标签是字符,我们把y转换成整数
# 将数据集分为训练集和测试集
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

# 我们可以看到第一张图片是5
some_digit = x[0]
some_digit_image = some_digit.reshape(28, 28) # 把长为784的一维数组转换成28x28的二维数组
# imshow用于生成图像,参数cmap用于设置图的Colormap,如果将当前图窗比作一幅简笔画,则cmap就代表颜料盘的配色
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off") # 关掉坐标轴
plt.show()

# 使用随机梯度下降(SGD)分类器,比如Scikit-Learn的SGDClassifier类
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
# max_iter最大迭代次数,random_state用于打乱数据,42表示一个随机数种子
sgd_clf = SGDClassifier(max_iter=1000, random_state=42)
sgd_clf.fit(x_train, y_train_5)  # 在整个训练集上进行训练

# 模型预测
print(sgd_clf.predict([some_digit]))  # 返回true

2.性能测量

①交叉验证(Cross-validation)

交叉验证就是将拿到的训练数据,分为训练和验证集。首先用训练集对模型进行训练,再利用验证集来测试该模型。
n折交叉验证:将数据分成n份,其中1份作为验证集。然后经过n次测试,每次都更换不同的验证集。得到n个结果,取平均值作为最终结果。

from sklearn.model_selection import cross_val_score
# sgd_clf是分类器,x_train表示训练实例,y_train_5表示每个训练实例对应的标签,cv=3表示3-折交叉验证,accuracy表示使用准确率作为结果的度量指标
cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring="accuracy") 

结果

array([0.95035, 0.96035, 0.9604 ])

但是,如果现在我有一个模型,对每张图片都判定为“非5”?,考虑到所有图片中有约10%的图片是5,这种模型训练下来的准确率也能达到90%左右,但是该模型永远正确无法识别“5‘。
这说明准确率通常无法成为分类器的首要性能指标,特别是当你处理有偏数据集时(即某些类比其他类更为频繁)。

②混淆矩阵

评估分类器性能的更好方法是混淆矩阵,其总体思路就是统计A被识别成B的次数

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# 获取每次预测的结果,cv=3表示3-折交叉验证
y_train_pred = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3)
# 构造混淆矩阵,y_train_5包含目标类别,y_train_pred包含对应的预测类别
confusion_matrix(y_train_5, y_train_pred)

结果

array([[53892,   687],
       [ 1891,  3530]])
# 真负类 假正类
# 假负类 真正类
# 真负类: 53892张“非5”图片被正确识别
# 假正类: 687张“非5”图片被错误地识别为“5”
# 假负类: 1891张“5”图片被错误地识别为“非5”
# 真正类:3530张“5”图片被正确识别

一个完美的分类器,它的副对角线的值都为0

# 我们可以假设我们都识别对了,打印出来看看
y_train_perfect_predictions = y_train_5
confusion_matrix(y_train_5, y_train_perfect_predictions)
# 结果
array([[54579,     0],
       [    0,  5421]])

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-22 18:36:57  更:2022-04-22 18:40:29 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:20:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码