【机器学习】KNN算法实现手写板字迹识别
1. 前言
? 上篇博客通过KNN算法实现鸢尾花数据集分类,在博客最后对KNN算法是否适合于图像分类进行了讨论。本篇博客通过KNN算法实现手写板字迹识别,通过人机交互的模式验证KNN算法是否合适于图像分类。
2. 实验背景
? 在制作手写体数据集的过程中,常会出现训练数据和测试数据都是出自同一个人的笔迹,此时KNN模型对测试数据友好,但这其实并不符合实际情况。因此本次实验将不同人的笔迹作为训练集,为了增加人机交互性,我们实现了一个简单的手写板进行测试集的制作。
数据集类型 | 数据集来源 | 数据集个数 |
---|
训练集 | 收集不同人群的手写体,并转换成01矩阵的txt文件格式 | 1974 | 测试集 | 通过python实现手写板,将自己的手写体作为测试集进行预测 | 1 |
3. 测试过程
测试步骤:
1. 设计手写板
2. 制作测试数据
3. 加载训练数据
4. 通过KNN算法进行训练
5. 得到预测结果
3.1 手写板及测试数据的制作
基本思路:通过手写板写入内容,将其保存为图片(512 x 512)。再将图片转换为01矩阵(32 x 32)并保存为txt文件格式作为测试数据。
def draw(event,x,y,flags,param):
global ix,iy
if event==cv2.EVENT_LBUTTONDOWN:
drawing=True
ix,iy=x,y
elif event==cv2.EVENT_MOUSEMOVE:
if drawing==True:
cv2.circle(img,(x,y),30,(0,0,0),-1)
elif event==cv2.EVENT_LBUTTONUP:
drawing=False
if __name__ == "__main__":
img=np.zeros((512,512,3),np.uint8)
for i in range(512):
img[i,:]=255
cv2.namedWindow('image')
cv2.setMouseCallback('image',draw)
while(1):
cv2.imshow('image',img)
if cv2.waitKey(1) & 0xFF == ord(' '):
cv2.imwrite('1.jpg',img)
break
cv2.destroyAllWindows()
img1 = cv2.imread('1.jpg', cv2.IMREAD_GRAYSCALE)
res=cv2.resize(img1,(32,32),interpolation=cv2.INTER_CUBIC)
pic=[]
for i in range(32):
for j in range(32):
if res[i][j]<=200:
res[i][j]=1
else:
res[i][j]=0
pic.append(int(res[i][j]))
filename = 'out.txt'
with open(filename, 'w') as name:
for i in range(32*32):
name.write(str(pic[i]))
if (i+1) % 32 == 0:
name.write("\n")
3.2 加载训练数据并进行KNN模型搭建
? 这里的KNN算法与之前鸢尾花分类的算法一样,首先将测试数据按照训练数据维度进行转换,并通过欧式距离进行排序(鸢尾花分类通过计算每条数据之间的距离;字迹识别通过计算每张图片之间的距离),最后选出K个最近的数据并筛选频率最高的标签作为预测结果。
? 数据加载方式的差别有两个,第一是尺寸信息更丰富,需要将(32,32)维度转为(1,1024)维度方便欧氏距离的计算;第二是标签需要自己设置,这里采用常见的文件名分类法。
def img2vector(filename):
vector = np.zeros((1, 1024))
file = open(filename)
for i in range(32):
str = file.readline()
for j in range(32):
vector[0, 32*i+j] = int(str[j])
return vector
def KNN(Test, Train, labels, k):
dataSetSize = Train.shape[0]
distance = np.tile(Test, (dataSetSize, 1)) - Train
sqdistance = distance ** 2
sqdistances = sqdistance.sum(axis=1)
distances = sqdistances ** 0.5
sortedDistIndicies = distances.argsort()
result = []
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
result.append(voteIlabel)
print(result)
collection = Counter(result)
result = collection.most_common(1)
return result[0][0]
def main():
labels = []
Train_list = listdir('knn/digits/trainingDigits')
batch = len(Train_list)
Train = np.zeros((batch, 1024))
for i in range(batch):
name = Train_list[i]
filename = name.split('.')[0].split('_')[0]
labels.append(filename)
Train[i, :] = img2vector('knn/digits/trainingDigits/%s' % name)
Test = img2vector("out.txt")
result = KNN(Test, Train, labels, 3)
print(result)
3.3 结果预测
训练数据:
链接:https://pan.baidu.com/s/1Zh0rYwvovmm4drEOpjLS8A
提取码:mjll
全部代码:
import cv2
import numpy as np
from os import listdir
import operator
from collections import Counter
drawing=False
def img2vector(filename):
vector = np.zeros((1, 1024))
file = open(filename)
for i in range(32):
str = file.readline()
for j in range(32):
vector[0, 32*i+j] = int(str[j])
return vector
def KNN(Test, Train, labels, k):
dataSetSize = Train.shape[0]
distance = np.tile(Test, (dataSetSize, 1)) - Train
sqdistance = distance ** 2
sqdistances = sqdistance.sum(axis=1)
distances = sqdistances ** 0.5
sortedDistIndicies = distances.argsort()
result = []
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
result.append(voteIlabel)
print(result)
collection = Counter(result)
result = collection.most_common(1)
return result[0][0]
def main():
labels = []
Train_list = listdir('knn/digits/trainingDigits')
batch = len(Train_list)
Train = np.zeros((batch, 1024))
for i in range(batch):
name = Train_list[i]
filename = name.split('.')[0].split('_')[0]
labels.append(filename)
Train[i, :] = img2vector('knn/digits/trainingDigits/%s' % name)
Test = img2vector("out.txt")
result = KNN(Test, Train, labels, 3)
print("预测结果:", result)
def draw(event,x,y,flags,param):
global ix,iy,drawing
if event==cv2.EVENT_LBUTTONDOWN:
drawing=True
ix,iy=x,y
elif event==cv2.EVENT_MOUSEMOVE:
if drawing==True:
cv2.circle(img,(x,y),30,(0,0,0),-1)
elif event==cv2.EVENT_LBUTTONUP:
drawing=False
if __name__ == "__main__":
img=np.zeros((512,512,3),np.uint8)
for i in range(512):
img[i,:]=255
cv2.namedWindow('image')
cv2.setMouseCallback('image',draw)
while(1):
cv2.imshow('image',img)
if cv2.waitKey(1) & 0xFF == ord(' '):
cv2.imwrite('1.jpg',img)
break
cv2.destroyAllWindows()
img1 = cv2.imread('1.jpg', cv2.IMREAD_GRAYSCALE)
res=cv2.resize(img1,(32,32),interpolation=cv2.INTER_CUBIC)
pic=[]
for i in range(32):
for j in range(32):
if res[i][j]<=200:
res[i][j]=1
else:
res[i][j]=0
pic.append(int(res[i][j]))
filename = 'out.txt'
with open(filename, 'w') as name:
for i in range(32*32):
name.write(str(pic[i]))
if (i+1) % 32 == 0:
name.write("\n")
main()
4. 总结
? 大家有兴趣可以直接copy运行看看,预测结果还算可观。但相比深度学习方法精度略显不足。这是因为图片的语义信息是高级信息,无法通过常规的聚类以及距离进行语义判定。我们在做KNN实验时可以尝试输出测试数据与每个训练数据之间的距离,当样本足够时会发现距离相同的样本并不少且标签还不同,这就体现除了KNN算法的一个bug。
? 总结得出,KNN算法可以在简单的图像数据上进行测试(例如:灰度数据),但在RGB图像甚至更加高维的数据上并不适用。
|