import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets
from sklearn.model_selection import cross_val_score
导包并加载数据
# Load the iris data as a (features, labels) pair and show the feature
# matrix dimensions.
# `return_X_y` is passed by keyword: it is a keyword-only parameter in
# current scikit-learn releases, so the positional form `load_iris(True)`
# fails there.
X, y = datasets.load_iris(return_X_y=True)
X.shape
(150, 4)
# Square root of the number of samples, derived from the data instead of
# the hard-coded 150 so it stays correct if the dataset changes.
# NOTE(review): presumably a rule-of-thumb upper bound for n_neighbors
# (the searches below stop at k = 13) - confirm.
X.shape[0] ** 0.5
12.24744871391589
cross_val_score交叉验证筛选最合适的参数
# Baseline: a KNeighborsClassifier built with its default settings,
# scored by 10-fold cross-validated accuracy. The final expression
# displays the mean of the per-fold scores.
knn = KNeighborsClassifier()
score = cross_val_score(
    knn,
    X,
    y,
    cv=10,
    scoring='accuracy',
)
score.mean()
0.96666666666666679
应用cross_val_score筛选最合适的邻居数量
errors = []
for k in range(1,14):
knn = KNeighborsClassifier(n_neighbors=k)
score = cross_val_score(knn,X,y,scoring='accuracy',cv=6).mean()
errors.append(1-score)
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(np.arange(1,14),errors)
[<matplotlib.lines.Line2D at 0x27f2d147748>]
# Compare the two neighbour-weighting schemes with n_neighbors fixed at 11,
# printing the scheme name and its mean 6-fold cross-validated accuracy.
weights = ['uniform', 'distance']
for w in weights:
    knn = KNeighborsClassifier(weights=w, n_neighbors=11)
    mean_accuracy = cross_val_score(
        knn, X, y, cv=6, scoring='accuracy'
    ).mean()
    print(w, mean_accuracy)
uniform 0.980709876543
distance 0.979938271605
多参数组合使用cross_val_score筛选最合适的组合参数
模型如何调参(参数调节)
# Exhaustive search over every (n_neighbors, weights) combination.
# Each entry maps a key made of the weighting scheme followed by k
# (e.g. 'uniform11') to the mean 6-fold cross-validated accuracy.
result = {}
for k in range(1, 14):
    for w in weights:
        knn = KNeighborsClassifier(weights=w, n_neighbors=k)
        fold_scores = cross_val_score(knn, X, y, cv=6, scoring='accuracy')
        result['{}{}'.format(w, k)] = fold_scores.mean()
# Display the full score table.
result
{'distance1': 0.95910493827160492,
'distance10': 0.97299382716049376,
'distance11': 0.97993827160493829,
'distance12': 0.97993827160493829,
'distance13': 0.97299382716049376,
'distance2': 0.95910493827160492,
'distance3': 0.96604938271604934,
'distance4': 0.96604938271604934,
'distance5': 0.96604938271604934,
'distance6': 0.97299382716049376,
'distance7': 0.97299382716049376,
'distance8': 0.97299382716049376,
'distance9': 0.97299382716049376,
'uniform1': 0.95910493827160492,
'uniform10': 0.97299382716049376,
'uniform11': 0.98070987654320996,
'uniform12': 0.97376543209876543,
'uniform13': 0.97376543209876543,
'uniform2': 0.93904320987654311,
'uniform3': 0.96604938271604934,
'uniform4': 0.96604938271604934,
'uniform5': 0.96604938271604934,
'uniform6': 0.97299382716049376,
'uniform7': 0.97299382716049376,
'uniform8': 0.95910493827160492,
'uniform9': 0.96604938271604934}
# Position (in the dict's iteration order) of the highest mean accuracy;
# ties resolve to the first occurrence.
np.argmax(list(result.values()))
20
# Key of the best-scoring combination, computed directly from `result`
# instead of indexing with a position hand-copied from an earlier output
# (which silently goes stale when the grid or the scores change).
# max() returns the first key with the maximal score, matching argmax's
# first-occurrence tie-breaking.
max(result, key=result.get)
'uniform11'
|