说明
《机器学习实战》第20–28页中 kNN(k-近邻)算法的代码,含部分注释。
代码
from numpy import *
import matplotlib.pyplot as plt
def classify(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distance is Euclidean, measured from ``inX`` to every row of ``dataSet``.

    Parameters
    ----------
    inX : sequence or ndarray
        The sample to classify; must have as many features as ``dataSet``
        has columns.
    dataSet : ndarray, shape (rows, features)
        Training samples, one per row.
    labels : sequence of length ``rows``
        Label of each training sample (indexed by row position).
    k : int
        Number of nearest neighbours that vote; must satisfy
        ``1 <= k <= rows``.

    Returns
    -------
    The label with the most votes. On a tie, the label whose first vote came
    from the nearer neighbour wins.

    Raises
    ------
    ValueError
        If ``k`` is outside ``1..rows``.
    """
    from collections import Counter

    rows = dataSet.shape[0]
    if not 1 <= k <= rows:
        raise ValueError("k must be between 1 and %d, got %r" % (rows, k))
    # Broadcasting subtracts inX from every row; equivalent to
    # tile(inX, (rows, 1)) - dataSet without building the tiled copy.
    diff = asarray(inX, dtype=float) - dataSet
    distances = sqrt((diff ** 2).sum(axis=1))
    nearest = distances.argsort()[:k]
    # Counter keeps first-insertion order, and most_common(1) returns the
    # first maximal entry, so ties go to the label of the nearer neighbour.
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]
def filetoInfo(filename):
    """Load a tab-separated dataset into a feature matrix and a label vector.

    On each non-blank line, the first three tab-separated fields are parsed
    as float features and the last field as an integer label.

    Parameters
    ----------
    filename : str or path-like
        Path of the text file to read.

    Returns
    -------
    mat : ndarray of float, shape (n, 3)
        Feature matrix, one row per non-blank line.
    labels : ndarray of int, shape (n,)
        Class label of each row.
    """
    # Read the file exactly once and close it deterministically (the old
    # version opened it twice and never closed either handle).
    with open(filename, encoding='utf-8') as fr:
        # Skip blank lines (e.g. a trailing newline at EOF); they would
        # otherwise yield a row of empty strings and break the conversion.
        rows = [line.strip().split('\t') for line in fr if line.strip()]
    mat = zeros((len(rows), 3))
    labels = []
    for index, fields in enumerate(rows):
        mat[index, :] = [float(v) for v in fields[0:3]]
        labels.append(int(fields[-1]))
    return mat, array(labels, dtype=int)
def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into the range [0, 1].

    Each value becomes ``(value - column_min) / (column_max - column_min)``.
    A column whose values are all equal (zero range) is mapped to 0 instead
    of producing NaN from a 0/0 division. Otherwise the NaN would propagate
    into every distance computed from the result.

    Parameters
    ----------
    dataSet : ndarray, shape (m, features)
        Samples, one per row.

    Returns
    -------
    normDataSet : ndarray of float, same shape as ``dataSet``
        The scaled data.
    ranges : ndarray, shape (features,)
        Per-column ``max - min``; zero ranges are returned unchanged.
    minVals : ndarray, shape (features,)
        Per-column minimum.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide by 1 where the range is 0; (value - min) is 0 there anyway.
    divisor = where(ranges == 0, 1, ranges)
    # Broadcasting replaces the explicit tile(...) copies.
    normDataSet = (dataSet - minVals) / divisor
    return normDataSet, ranges, minVals
def datingClassTest(testRate=0.5, k=3, normalize=True,
                    filename='datingTestSet2.txt', verbose=True):
    """Hold-out evaluation of ``classify`` on the dating dataset.

    The first ``int(m * testRate)`` rows are classified against the remaining
    rows. The error rate and error count are printed and the error rate is
    returned.

    Calling with no arguments matches the previously effective definition
    (50% hold-out, k=3, normalized features, per-row output). The earlier,
    shadowed variant (10% hold-out, k=10, raw features) is available as
    ``datingClassTest(testRate=0.1, k=10, normalize=False)``.

    Parameters
    ----------
    testRate : float
        Fraction of rows (taken from the start of the file) used as test set.
    k : int
        Number of neighbours passed to ``classify``.
    normalize : bool
        Min-max scale the features with ``autoNorm`` first, so that columns
        with large numeric ranges do not dominate the Euclidean distance.
    filename : str
        Dataset path, read with ``filetoInfo``.
    verbose : bool
        Print the predicted and real label of every test row.

    Returns
    -------
    float
        The fraction of test rows that were misclassified.

    Raises
    ------
    ValueError
        If the split leaves no test rows or no training rows.
    """
    mat, labels = filetoInfo(filename)
    if normalize:
        mat = autoNorm(mat)[0]
    m = mat.shape[0]
    test = int(m * testRate)
    # Guard against an empty test set (division by zero below) or an empty
    # training set (nothing to compare against).
    if not 0 < test < m:
        raise ValueError("testRate %r gives %d test rows out of %d"
                         % (testRate, test, m))
    errCnt = 0
    for i in range(test):
        result = classify(mat[i, :], mat[test:m, :], labels[test:m], k)
        if verbose:
            print("result: %d real: %d" % (result, labels[i]))
        if result != labels[i]:
            errCnt += 1
    errRate = errCnt / float(test)
    print("err rate: %f" % (errRate))
    print("err count: %d" % (errCnt))
    return errRate
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    datingClassTest()
注意
dataSet 中每一行的特征个数和顺序必须一致,待分类样本 inX 的特征个数也必须等于 dataSet 的列数;若各特征的数值范围相差很大,应先用 autoNorm 归一化,否则数值大的特征会主导距离计算。
总结
kNN 算法的逻辑较为简单。建议用调试(debug)模式单步执行,逐个查看每个变量的值,就能逐渐理解代码前后的逻辑。