目录
使用MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。 该数据集分成训练集(前6万张图片)和测试集(最后1万张图片)
1.训练二元分类器
先简化问题,只尝试识别一个数字,比如数字5
from sklearn.datasets import fetch_openml # 我的sklearn版本为1.0.2
from sklearn.linear_model import SGDClassifier
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
# 下载mnist数据集,fetch_openml默认返回的是一个DataFrame,设置as_frame=False返回一个Bunch
# mnist.keys() 可查看所有的键
# data键,包含一个数组,每个实例为一行,每个特征为一列。
# target键,包含一个带有标记的数组
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# 共有7万张图片,因为图片是28×28像素,所以每张图片有784个特征,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)
x, y = mnist["data"], mnist["target"] # x.shape=(70000, 784),y.shape=(70000,)
y = y.astype(np.uint8) # 注意标签是字符,我们把y转换成整数
# 将数据集分为训练集和测试集
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
# 我们可以看到第一张图片是5
some_digit = x[0]
some_digit_image = some_digit.reshape(28, 28) # 把长为784的一维数组转换成28x28的二维数组
# imshow用于生成图像,参数cmap用于设置图的Colormap,如果将当前图窗比作一幅简笔画,则cmap就代表颜料盘的配色
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off") # 关掉坐标轴
plt.show()
# 使用随机梯度下降(SGD)分类器,比如Scikit-Learn的SGDClassifier类
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
# max_iter最大迭代次数,random_state用于打乱数据,42表示一个随机数种子
sgd_clf = SGDClassifier(max_iter=1000, random_state=42)
sgd_clf.fit(x_train, y_train_5) # 在整个训练集上进行训练
# 模型预测
print(sgd_clf.predict([some_digit])) # 返回true
2.性能测量
①交叉验证(Cross-validation)
交叉验证就是将拿到的训练数据,分为训练和验证集。首先用训练集对模型进行训练,再利用验证集来测试该模型。 n折交叉验证:将数据分成n份,其中1份作为验证集。然后经过n次测试,每次都更换不同的验证集。得到n个结果,取平均值作为最终结果。
from sklearn.model_selection import cross_val_score
# sgd_clf是分类器,x_train表示训练实例,y_train_5表示每个训练实例对应的标签,cv=3表示3-折交叉验证,accuracy表示使用准确率作为结果的度量指标
cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring="accuracy")
结果
array([0.95035, 0.96035, 0.9604 ])
但是,如果现在我有一个模型,对每张图片都判定为“非5”?,考虑到所有图片中有约10%的图片是5,这种模型训练下来的准确率也能达到90%左右,但是该模型永远正确无法识别“5‘。 这说明准确率通常无法成为分类器的首要性能指标,特别是当你处理有偏数据集时(即某些类比其他类更为频繁)。
②混淆矩阵
评估分类器性能的更好方法是混淆矩阵,其总体思路就是统计A被识别成B的次数。
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
# 获取每次预测的结果,cv=3表示3-折交叉验证
y_train_pred = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3)
# 构造混淆矩阵,y_train_5包含目标类别,y_train_pred包含对应的预测类别
confusion_matrix(y_train_5, y_train_pred)
结果
array([[53892, 687],
[ 1891, 3530]])
# 真负类 假正类
# 假负类 真正类
# 真负类: 53892张“非5”图片被正确识别
# 假正类: 687张“非5”图片被错误地识别为“5”
# 假负类: 1891张“5”图片被错误地识别为“非5”
# 真正类:3530张“5”图片被正确识别
一个完美的分类器,它的副对角线的值都为0
# 我们可以假设我们都识别对了,打印出来看看
y_train_perfect_predictions = y_train_5
confusion_matrix(y_train_5, y_train_perfect_predictions)
# 结果
array([[54579, 0],
[ 0, 5421]])
?
|