K近邻算法 (KNN)
主要思路:计算每个点和某点的距离,取距离最短的K个点的下标即可。 下面是个完整示例,代码复制即可运行
import torch
import time
def coordinate_gen(n):
"""
生成n个三位点
return tensor
dim:n*3
"""
xyz = torch.rand(size=(n,3))
return xyz
def time_cost(f):
def run_time(*args,**kwargs):
start = time.time()
res = f(*args,**kwargs)
run_times = time.time()-start
print("程序执行时间:%.6f s"%(run_times))
return res
return run_time
@time_cost
def knn(xyz,xyzs,k=3):
"""
xyz:key point
xyzs:all points
找某点的k近邻个点
return 近邻点的下标列表
"""
idx = [0]*k
distance = torch.sum((xyzs[:,:3]-xyz[:,:3])**2,dim=-1)
for i in range(k):
idx[i] = torch.argmin(distance,dim=0)
distance[int(torch.argmin(distance,dim=0))] = float('inf')
idx = [int(i) for i in idx]
return idx
if __name__ == "__main__":
print('-' * 20, '测试开始', '-' * 20)
N,k = map(int,input("输入生成点数 和 k的值:").split())
xyzs = coordinate_gen(N)
xyz = torch.rand(size=(1,3))
print("生成的点如下:\n",xyzs,"\n随机生成key point:",xyz)
print(knn(xyz,xyzs=xyzs,k=k))
print('-'*20,'测试结束','-'*20)
最远点采样(FPS)
def farthest_point_sample(data,npoints):
"""
Args:
data:输入的tensor张量,排列顺序 N,D
Npoints: 需要的采样点
Returns:data->采样点集组成的tensor,每行是一个采样点
"""
N,D = data.shape
xyz = data[:,:3]
centroids = torch.zeros(size=(npoints,))
dictance = torch.ones(size=(N,))*1e10
farthest = torch.randint(low=0,high=N,size=(1,))
for i in range(npoints):
centroids[i] = farthest
centroid = xyz[farthest,:]
dict = ((xyz-centroid)**2).sum(dim=-1)
mask = dict < dictance
dictance[mask] = dict[mask]
farthest = torch.argmax(dictance,dim=-1)
print(centroids.type(torch.long))
data= data[centroids.type(torch.long)]
return data
|