IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> KNN算法的学习以及手写数字识别的实现 -> 正文阅读

[人工智能]KNN算法的学习以及手写数字识别的实现

前言

一、K-近邻算法是什么?

简而言之,k-近邻算法就是采用测量不同特征值之间的距离方法进行分类。

  • k-近邻算法的工作原理:
    存在一个样本数据集合(训练样本集),在样本集当中的每个数据都存在标签,即我们知道样本集中每一个数据与其所属分类的对应关系。在输入无标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后通过算法提取出样本集中特征最为相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法当中k的由来,通常k为不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
    举例说明:
    当k=3时,距离最近的3个样本为2个红色三角形与1个蓝色正方形,因此将它归类为红色三角形
    当k=5时,距离最近的5个样本为2个红色三角形与3个蓝色正方形,因此将它归类为蓝色正方形
    在这里插入图片描述

  • k-近邻算法的优缺点:
    优点:精度高,对异常值不敏感,无数据输入假定
    缺点:计算复杂度高,空间复杂度高
    适用数据范围:数值型和标称型

  • k-近邻算法的一般流程:
    (1)收集数据:可以使用任何方法
    (2)准备数据:距离计算所需要的数值,最好是结构化的数据格式
    (3)分析数据:可以使用任何方法
    (4)训练算法:此步骤不适用于k-近邻算法
    (5)测试算法:计算错误率
    (6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理

二、实验准备:使用python导入数据

  1. 先创建一个knn.py文件,然后编写一些通用函数
from numpy import * #导入numpy库
from os import listdir #导入os模块中的listdir,listdir()方法返回一个列表
import operator #导入operator模块
def createDataSet(): #创建数据集和标签
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group,labels

在这部分代码中我们导入了科学计算包numpy,运算符模块operator以及os模块当中的listdir方法。NumPy库总包含两种基本的数据类型:矩阵和数组,矩阵的使用类似Matlab,本实例用得多的是数组array。operator模块则在k-近邻算法执行排序操作时会使用到。listdir函数可以列出给定目录的文件名,在测试算法中会用到。

2.实施k-近邻算法
伪代码:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1)计算已知类别数据集中的点与当前点之间的距离;
(2)按照距离递增次序排序;
(3)选取与当前点距离最小的k个点;
(4)确定前k个点所在类别的出现频率;
(5)返回前k个点出现频率最高的类别作为当前点的预测分类。

def classify0(inX,dataSet,labels,k): #inX:用于分类的输入 dataSet:输入的训练样本集 labels:标签向量 k:用于选择最近邻居的数目
    dataSetSize = dataSet.shape[0]   #shape函数是numpy.core.fromnumeric中的函数,它的功能是读取矩阵的长度,比如shape[0]就是读取矩阵第一维度的长度。
    diffMat = tile(inX,(dataSetSize,1)) - dataSet #tile函数 位于 python 模块 numpy .lib.shape_base中,他的功能是重复某个数组。. 比如 tile (A,n),功能是将数组A重复n次,构成一个新的数组
    sqDiffMat = diffMat**2 #平方
    sqDistances = sqDiffMat.sum(axis=1) # 普通sum默认参数为axis=0为普通相加,axis=1为一行的行向量相加
    distances = sqDistances**0.5 #开根号
    sortedDistIndicies = distances.argsort() # argsort返回数值从小到大的索引值(数组索引0,1,2,3)
    classCount={}    # 初始化classCount字典 选择距离最小的k个点
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]         # 根据排序结果的索引值返回靠近的前k个标签
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 # 各个标签出现频率
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) #将字典分解为元组列表并以itemgetter方法进行排序,最后返回发生频率最高的元素标签
    return sortedClassCount[0][0]

在此代码中,计算两个向量点xA和xB之间的距离时使用了欧式距离公式:d= ( x A 0 ? x B 0 ) 2 + ( x A 1 ? x B 1 ) 2 \sqrt {(xA0-xB0)^2+(xA1-xB1)^2} (xA0?xB0)2+(xA1?xB1)2 ?,在计算完所有点之间的距离后,就可以对数据按照从小到大的顺序排序,然后,确定前k个距离最小元素所在的主要分类,最后将classCount字典分解为元组列表,然后通过itemgetter方法按照第二个元素的次序对元组进行排序,此处的排序为逆序,最后返回发生频率最高的元素标签。

	group,labels = createDataSet()
    print(classify0([0,0],group,labels,3))

上述代码输出结果为B,当然若改变输入[0,0]为其他值,输出结果也会发生改变,如将输入[0,0]改为[1.0,1.1],则输出为A。

三、实验:实现手写识别系统

1.准备数据:将图像转换为测试向量
在该实验当中,我们首先要将书上所提供的实际图像转换为向量,在trainingDigits文件夹中包含了大约2000个例子,每个数字大约有200个样本,在testDigits文件夹中包含了大约900个测试数据。我们要先将一个32×32的二进制图像矩阵转换为1×1024的向量,这样才能使用之前写的分类器来处理数字图像信息。
首先编写img2vector函数,将图像转换为向量:该函数创建1*1024的numpy数组,然后打开给定文件,循环读出文件前32行,并将每行的头32个字符值存储在numpy数组中,最后返回数组。

def img2vector(filename):
    returnVect = zeros((1, 1024))    # 每个手写识别为32*32大小的二进制图像矩阵,转换为1*1024 numpy向量数组returenVect
    fr = open(filename)
    for i in range(32):    # 循环读出前32行
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])            # 将每行的32个字符值存储在numpy数组中
    return returnVect

2.测试算法:使用k-近邻算法识别手写数字
在将数据处理成分类器可识别的格式后,将数据输入到分类器中,检测分类器的执行效果。在该部分测试代码中,还需从os模块中导入listdir函数,通过该函数可以列出给定目类的文件名。

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('D:/machinelearning/machinelearninginaction/Ch02/trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):    # 定义文件数 x 每个向量的训练集
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0]) # 解析文件名
        hwLabels.append(classNumStr)        # 存储类别
        trainingMat[i,:] = img2vector('D:/machinelearning/machinelearninginaction/Ch02/trainingDigits/%s' % fileNameStr)        # 访问第i个文件内的数据
    testFileList = listdir('D:/machinelearning/machinelearninginaction/Ch02/testDigits')    # 测试数据集
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])        # 从文件名中分离出数字
        vectorUnderTest = img2vector('D:/machinelearning/machinelearninginaction/Ch02/testDigits/%s' % fileNameStr)        # 访问第i个文件内的测试数据,不存储类 直接测试
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
        print("the classifier came back with: %d,the real answer is: %d" % (classifierResult,classNumStr))
        if(classifierResult != classNumStr):errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))

在handwritingClassTest()函数当中,我们先将trainingDigits目录中的文件内容存储在列表中,便可得到目录中含有多少文件,并将其存储在变量m当中。然后创建一个m行1024列的训练矩阵,该矩阵的每行数据存储一个图像。我们从文件名中解析出分类数字,然后将类代码存储在hwLabels向量当中,再使用img2vector函数载入图像,然后对testDigits目录中的文件执行相似操作,不过我们并不将该目录下的文件载入矩阵当中,而是使用classify0函数测试该目录下的每个文件。
测试结果:
在这里插入图片描述
通过测试结果可以看出错误率为1.1%。改变变量k的值、修改函数handwritingClassTest随机选取训练样本、改变训练样本的数目都会对k-近邻算法的错误率产生影响。
举例:当把k值修改为5后,错误率发生了变化
在这里插入图片描述

在这里插入图片描述

总结

通过本次实验,我学会了k-近邻算法的应用。k-近邻算法是分类数据最简单最有效的算法,k-近邻算法是给予实例的学习,在使用算法时我们必须有接近实际数据的训练样本数据。该算法精度高、对异常值不敏感,并且无数据输入假定,但该算法必须保存所有数据集,这样会使用大量的存储空间,空间复杂度高,并且还必须对数据集中的每个数据计算距离值,因此计算复杂度高。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-20 15:47:40  更:2021-09-20 15:49:42 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 16:43:26-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码