?作为一个kdtree建立和knn搜索笔记。
如有错误欢迎留言,谢谢。
import numpy as np
import math
class Node:
def __init__(self,elt=None,LL=None,RR=None,split=None):
self.left=LL #左子树
self.right=RR #右子树
self.split=split #划分的超平面空间(就是切割面)
self.elt=elt #具体的数据点
def building_tree(root, data): #建树
if len(data) < 1: #如果没有数据点传进来,就直接返回。大白话:没有要切割的点了
return
maxvar = 0 #最大方差
split=0 #对于二维来说0是垂直于x轴1是垂直于y轴
dim = data.shape[1] #获取维度
item=[]
data_list=data.tolist() #矩阵转化为列表
datat=data.transpose() #矩阵的转秩,方便后面数的提取
for i in range(dim):
item.clear()
for t in datat[i]: #获取x(y)轴所有的值计算方差
item.append(t)
var = culvar(item)
if maxvar < var: #选出方差最大的那一个为超平面
split = i
maxvar = var
print("超平面:%d,最大方差:%d"%(split,maxvar))
mediam = data.shape[0] // 2 #取出中位数的下标
data_list.sort(key=lambda x:x[split]) #排序
elt = data_list[mediam] #取出中位数
print(elt)
root=Node(elt=elt,split=split) #建立一个节点
#print("当前数据点:",np.array(data_list[0:mediam]))
# 不断递归,取出x轴小于中位数的数据点作为下一个平面内的数据点
root.left = node.building_tree(root.left,np.array(data_list[0:mediam]))
#这里是要用你实例化的对象建树了!
root.right = node.building_tree(root.right,np.array(data_list[mediam+1:]))
return root
def search(target,root):
NN=root.elt #获取根节点的数据点
#print("NN:",root.split)
#print("target:", type(target))
min_dis=culdistance(target,NN) #计算最坏距离
nodelist=[]
temp_root=root
while temp_root: #直到循环到叶子节点结束
nodelist.append(temp_root) #模拟堆栈,先进后出,给后面的回溯做铺垫
splt = temp_root.split #取出我这个数据点超平面
# print("split:",temp_root.split)
dist=culdistance(target,temp_root.elt)
if dist<min_dis: #如果有比最坏距离还要小的就存下距离和对应的数据点
min_dis=dist
NN=temp_root.elt
#在现在这个节点的超平面下,目标点和我现在的节点里的数据点距离,判断我要给的节点是左节点还是右
if target[splt]<=temp_root.elt[splt]:
temp_root=temp_root.left
else:
temp_root = temp_root.right
while nodelist: #回溯从叶子回到根(可能会跑到根的另一端找)
back_root=nodelist.pop()
splt=back_root.split
# 计算在这个超平面内的距离(最短距离:就是目标点垂直于超平面的距离)
if abs(target[splt]-back_root.elt[splt])<min_dis:
#只要被我的圆所包围的所有超平面我都要去遍历的
if target[splt]>back_root.elt[splt]:#叶子已经找过了肯定直接找下一个节点了
temp_root=back_root.right
else:
temp_root=back_root.left
if temp_root: #只要不是叶子被弹出这个都会执行去看看是否有点比我刚刚找的更近
nodelist.append(temp_root)
cur_dist=culdistance(target,temp_root.elt)
if cur_dist<min_dis:
min_dis=cur_dist
NN = temp_root.elt
return NN,min_dis
def culdistance(p1,p2):
sum=0
for i in range(len(p1)):
sum=sum+(p1[i]-p2[i])*(p1[i]-p2[i])
return math.sqrt(sum)
def culvar(value):
value=np.array(value)
ex = value.mean(axis=0)
ex2 = pow(value,2).mean(axis=0)
return ex2-ex*ex
if __name__ == "__main__":
node=Node
data=np.array([[2,3],[5,4],[7,2],[8,1],[9,6],[4,7]])
#print(data)
tree_root=node.building_tree(None,data)
point,min_distance=search([2.1,3.1],tree_root)
print(point,min_distance)
结果:?
?
?
?
感谢大佬:(52条消息) kd-tree的python实现_Flying Dreams-CSDN博客_kdtree python
|