I. KNN基本原理
关于KNN的基本原理可以参考我之前写的一篇文章:最简单的分类算法之一:KNN(原理解析+代码实现)
II. 数据处理
导入torchvision.datasets.MNIST数据集:
def load_data(root='./data/', download=False):
    """Load MNIST and flatten every image into one row per sample.

    Args:
        root: directory holding the MNIST files. Defaults to ``'./data/'``.
        download: forwarded to ``torchvision.datasets.MNIST``; set to True to
            fetch the files when they are missing under ``root``. Defaults to
            False, matching the previous behaviour.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)`` as NumPy arrays. ``X_*`` have
        shape ``(n_samples, 784)`` for standard 28x28 MNIST; ``Y_*`` hold the
        digit labels.
    """
    # No transform is passed: ``.data`` / ``.targets`` read the raw stored
    # tensors and never go through the dataset's transform pipeline, so a
    # ToTensor() here would have no effect on the returned arrays.
    dataset_train = torchvision.datasets.MNIST(root=root, train=True, download=download)
    dataset_test = torchvision.datasets.MNIST(root=root, train=False, download=download)
    X_train = dataset_train.data.numpy()
    X_test = dataset_test.data.numpy()
    # Flatten each image with -1 instead of hard-coding (60000, 784) and
    # (10000, 784), so the split sizes are taken from the data itself.
    X_train = np.reshape(X_train, (X_train.shape[0], -1))
    X_test = np.reshape(X_test, (X_test.shape[0], -1))
    Y_train = dataset_train.targets.numpy()
    Y_test = dataset_test.targets.numpy()
    return X_train, Y_train, X_test, Y_test
训练集中含有60000条数据,测试集中含有10000条数据。任意输出一条数据:
# Show one flattened training image followed by its label.
print(X_train[0], Y_train[0], sep=' ')
结果为:
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 3 18 18 18 126 136 175 26 166 255
247 127 0 0 0 0 0 0 0 0 0 0 0 0 30 36 94 154
170 253 253 253 253 253 225 172...] 5
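X_train[0] 是一个长度为 784(即 28×28 的图像按行展平)的一维数组,像素取值为 0–255 的 uint8 整数;Y_train[0] 是其对应的数字标签(此处为 5)。注意:uint8 数组直接相减会发生溢出回绕(例如 3 − 18 得到 241),因此计算距离前应先把数据转换为浮点类型。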
III. 手写KNN
为了减少计算量,手写KNN只使用训练集的前30000条数据和测试集的前1000条数据。
def get_distance(x1, x2):
    """Return the Euclidean (L2) distance between two samples.

    Both inputs are promoted to float64 before subtracting. Raw MNIST pixel
    arrays are uint8, and uint8 subtraction wraps around modulo 256
    (``3 - 18 == 241``), which would silently corrupt the distance.
    """
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return np.linalg.norm(diff)
def get_vec(K, x, train_x, train_y):
    """Return the ``K`` training samples closest to ``x``.

    Args:
        K: number of neighbours to keep (all of them if the training set is
            smaller than ``K``).
        x: the query sample.
        train_x: training samples indexed along the first axis; each sample
            must contain the same number of elements as ``x``.
        train_y: labels aligned with ``train_x``.

    Returns:
        A list of ``[distance, label]`` pairs in ascending Euclidean distance.
        Equal distances keep their training-set order.
    """
    # Work in float64: raw MNIST pixels are uint8, and uint8 subtraction
    # wraps around (3 - 18 == 241), which would corrupt every distance.
    train = np.asarray(train_x, dtype=np.float64)
    if train.shape[0] == 0:
        return []
    query = np.asarray(x, dtype=np.float64).ravel()
    # Flatten each sample and compute all distances in one vectorised pass
    # instead of calling a distance function once per training row.
    distances = np.linalg.norm(train.reshape(train.shape[0], -1) - query, axis=1)
    # A stable sort keeps ties in training-set order, as sorted() did.
    nearest = np.argsort(distances, kind='stable')[:K]
    return [[distances[i], train_y[i]] for i in nearest]
def _weighted_vote(vec):
    """Return the label with the largest total weight among the neighbours.

    ``vec`` is a list of ``[distance, label]`` pairs. Each neighbour gets the
    weight ``1 - distance / sum_of_distances``, so nearer neighbours count more.
    """
    total = sum(d for d, _ in vec)
    scores = {}
    for d, label in vec:
        # When every neighbour is at distance 0, the ratio is 0/0 (NaN).
        # Fall back to equal weights so the vote stays meaningful.
        weight = 1 - d / total if total > 0 else 1.0
        scores[label] = scores.get(label, 0.0) + weight
    # Dict order is first-seen order, and max() returns the first maximal
    # key, so a tie goes to the label met earliest in ``vec``.
    return max(scores, key=scores.get)


def knn(K, n_train=30000, n_test=1000):
    """Evaluate the hand-written distance-weighted KNN classifier on MNIST.

    Only the first ``n_train`` training samples and the first ``n_test``
    test samples are used, to keep the brute-force search affordable. The
    function prints the true and predicted label of every test sample, then
    the accuracy.

    Args:
        K: number of nearest neighbours that vote.
        n_train: size of the training subset (default 30000).
        n_test: size of the test subset (default 1000).

    Returns:
        The accuracy on the test subset, as a float.
    """
    train_x, train_y, test_x, test_y = load_data()
    train_x, train_y = train_x[:n_train], train_y[:n_train]
    test_x, test_y = test_x[:n_test], test_y[:n_test]
    cnt = 0
    for x, y in zip(test_x, test_y):
        vec = get_vec(K, x, train_x, train_y)
        pred = _weighted_vote(vec)
        if y == pred:
            cnt += 1
        print(y, pred)
    accuracy = cnt / len(test_x)
    print('accuracy:', accuracy)
    return accuracy
if __name__ == '__main__':
    # Number of nearest neighbours that take part in each vote.
    K = 10
    knn(K=K)
如果想要提升精度,可以增大训练集的规模(例如使用完整的60000条训练数据),代价是每个测试样本需要计算的距离也会相应增多。
IV. sklearn.KNeighborsClassifier
采用sklearn中的KNeighborsClassifier对数据进行训练和测试,此时训练集和测试集都包括完整的数据:
if __name__ == '__main__':
    K = 10
    # Use the complete training and test sets here (no subsampling).
    train_x, train_y, test_x, test_y = load_data()
    # Named ``clf`` rather than ``knn`` so it does not shadow the
    # hand-written knn() function.
    clf = KNeighborsClassifier(n_neighbors=K)
    clf.fit(train_x, train_y)
    # score() returns the mean accuracy on the given test data.
    acc = clf.score(test_x, test_y)
    print('accuracy:', acc)
结果:
V. 完整代码
"""
@Time : 2022/1/7 21:28
@Author :KI
@File :knn-mnist.py
@Motto:Hungry And Humble
"""
import numpy as np
import torchvision
import torchvision.transforms as transforms
from sklearn.neighbors import KNeighborsClassifier
def load_data(root='./data/', download=False):
    """Load MNIST and flatten every image into one row per sample.

    Args:
        root: directory holding the MNIST files. Defaults to ``'./data/'``.
        download: forwarded to ``torchvision.datasets.MNIST``; set to True to
            fetch the files when they are missing under ``root``. Defaults to
            False, matching the previous behaviour.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)`` as NumPy arrays. ``X_*`` have
        shape ``(n_samples, 784)`` for standard 28x28 MNIST; ``Y_*`` hold the
        digit labels.
    """
    # No transform is passed: ``.data`` / ``.targets`` read the raw stored
    # tensors and never go through the dataset's transform pipeline, so a
    # ToTensor() here would have no effect on the returned arrays.
    dataset_train = torchvision.datasets.MNIST(root=root, train=True, download=download)
    dataset_test = torchvision.datasets.MNIST(root=root, train=False, download=download)
    X_train = dataset_train.data.numpy()
    X_test = dataset_test.data.numpy()
    # Flatten each image with -1 instead of hard-coding (60000, 784) and
    # (10000, 784), so the split sizes are taken from the data itself.
    X_train = np.reshape(X_train, (X_train.shape[0], -1))
    X_test = np.reshape(X_test, (X_test.shape[0], -1))
    Y_train = dataset_train.targets.numpy()
    Y_test = dataset_test.targets.numpy()
    return X_train, Y_train, X_test, Y_test
def get_distance(x1, x2):
    """Return the Euclidean (L2) distance between two samples.

    Both inputs are promoted to float64 before subtracting. Raw MNIST pixel
    arrays are uint8, and uint8 subtraction wraps around modulo 256
    (``3 - 18 == 241``), which would silently corrupt the distance.
    """
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return np.linalg.norm(diff)
def get_vec(K, x, train_x, train_y):
    """Return the ``K`` training samples closest to ``x``.

    Args:
        K: number of neighbours to keep (all of them if the training set is
            smaller than ``K``).
        x: the query sample.
        train_x: training samples indexed along the first axis; each sample
            must contain the same number of elements as ``x``.
        train_y: labels aligned with ``train_x``.

    Returns:
        A list of ``[distance, label]`` pairs in ascending Euclidean distance.
        Equal distances keep their training-set order.
    """
    # Work in float64: raw MNIST pixels are uint8, and uint8 subtraction
    # wraps around (3 - 18 == 241), which would corrupt every distance.
    train = np.asarray(train_x, dtype=np.float64)
    if train.shape[0] == 0:
        return []
    query = np.asarray(x, dtype=np.float64).ravel()
    # Flatten each sample and compute all distances in one vectorised pass
    # instead of calling a distance function once per training row.
    distances = np.linalg.norm(train.reshape(train.shape[0], -1) - query, axis=1)
    # A stable sort keeps ties in training-set order, as sorted() did.
    nearest = np.argsort(distances, kind='stable')[:K]
    return [[distances[i], train_y[i]] for i in nearest]
def _weighted_vote(vec):
    """Return the label with the largest total weight among the neighbours.

    ``vec`` is a list of ``[distance, label]`` pairs. Each neighbour gets the
    weight ``1 - distance / sum_of_distances``, so nearer neighbours count more.
    """
    total = sum(d for d, _ in vec)
    scores = {}
    for d, label in vec:
        # When every neighbour is at distance 0, the ratio is 0/0 (NaN).
        # Fall back to equal weights so the vote stays meaningful.
        weight = 1 - d / total if total > 0 else 1.0
        scores[label] = scores.get(label, 0.0) + weight
    # Dict order is first-seen order, and max() returns the first maximal
    # key, so a tie goes to the label met earliest in ``vec``.
    return max(scores, key=scores.get)


def knn(K, n_train=30000, n_test=1000):
    """Evaluate the hand-written distance-weighted KNN classifier on MNIST.

    Only the first ``n_train`` training samples and the first ``n_test``
    test samples are used, to keep the brute-force search affordable. The
    function prints the true and predicted label of every test sample, then
    the accuracy.

    Args:
        K: number of nearest neighbours that vote.
        n_train: size of the training subset (default 30000).
        n_test: size of the test subset (default 1000).

    Returns:
        The accuracy on the test subset, as a float.
    """
    train_x, train_y, test_x, test_y = load_data()
    train_x, train_y = train_x[:n_train], train_y[:n_train]
    test_x, test_y = test_x[:n_test], test_y[:n_test]
    cnt = 0
    for x, y in zip(test_x, test_y):
        vec = get_vec(K, x, train_x, train_y)
        pred = _weighted_vote(vec)
        if y == pred:
            cnt += 1
        print(y, pred)
    accuracy = cnt / len(test_x)
    print('accuracy:', accuracy)
    return accuracy
if __name__ == '__main__':
    # Number of nearest neighbours that take part in each vote.
    K = 10
    knn(K=K)
|