Found a fast GPU KNN algorithm: KNN_CUDA
DGCNN KNN
The KNN used before was DGCNN's torch.topk()-based implementation, but it is very inefficient: during training it both eats GPU memory and runs slowly. The common substitute, ball_query from pointnet2 (PointNet++), is not an exact KNN, which can hurt training accuracy.
import torch

def knn(x, k):
    # x: (batch, dims, num_points). Expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2,
    # negated so that topk picks the k smallest distances.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch, num_points, k)
    return idx, pairwise_distance
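As a quick sanity check on the expanded distance formula (my own snippet, not part of DGCNN), negating pairwise_distance should reproduce the squared Euclidean distances from torch.cdist:

x = torch.rand(2, 3, 128)
_, pairwise_distance = knn(x, k=4)
ref = torch.cdist(x.transpose(2, 1), x.transpose(2, 1)) ** 2
print(torch.allclose(-pairwise_distance, ref, atol=1e-4))  # True up to float error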
def get_graph_feature(x, k=20):
    # x: (batch, dims, num_points) -> edge features: (batch, 2*dims, num_points, k)
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch, num_points, k)
    device = torch.device('cuda')
    # Offset the indices so each batch element addresses its own rows after flattening.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)
    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (batch, num_points, dims)
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((neighbor - x, x), dim=3).permute(0, 3, 1, 2)
    return feature
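For example (input shape follows the DGCNN convention of (batch, channels, points); the concrete sizes here are just illustrative):

x = torch.rand(8, 3, 1024).cuda()
feature = get_graph_feature(x, k=20)
print(feature.shape)  # expected: torch.Size([8, 6, 1024, 20])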
KNN_CUDA
Looking at the PoinTr code, I found that the edge_conv implemented in dgcnn_group.py is different:
- it uses knn_cuda instead of a hand-written knn;
- it wraps the neighbor search in with torch.no_grad(), since the integer indices need no gradients, which saves memory.
from knn_cuda import KNN

knn = KNN(k=16, transpose_mode=False)

def get_graph_feature(coor_q, x_q, coor_k, x_k):
    # coor: (bs, 3, np) point coordinates; x: (bs, c, np) point features
    k = 16
    batch_size = x_k.size(0)
    num_points_k = x_k.size(2)
    num_points_q = x_q.size(2)
    with torch.no_grad():
        _, idx = knn(coor_k, coor_q)  # (bs, k, np_q)
        assert idx.shape[1] == k
        idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x_k.size(1)
    x_k = x_k.transpose(2, 1).contiguous()
    feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
    feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
    x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
    feature = torch.cat((feature - x_q, x_q), dim=1)  # (bs, 2c, np_q, k)
    return feature
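A quick shape check (my own snippet; the concrete sizes are assumptions chosen to match the comments above):

coor = torch.rand(2, 3, 1024).cuda()
feat = torch.rand(2, 64, 1024).cuda()
out = get_graph_feature(coor, feat, coor, feat)
print(out.shape)  # expected: torch.Size([2, 128, 1024, 16])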
Next work
Next, I plan to test the difference between the two implementations in an actual model, in terms of both accuracy and speed, hopefully before 2022/3/27. I slightly modified the PoinTr code below so that its inputs and outputs match the DGCNN version.
from knn_cuda import KNN

knn_cuda = KNN(k=20, transpose_mode=False)

def get_graph_feature_v3(x, k=20, idx=None):
    # k is effectively fixed to 20, since the module-level KNN object was built with k=20.
    k = 20
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.contiguous()
    with torch.no_grad():
        _, idx = knn_cuda(x, x)  # (bs, k, np)
        assert idx.shape[1] == k
        idx = idx.transpose(1, 2).contiguous()  # (bs, np, k), same layout as the topk version
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature
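To confirm the two versions really agree, a quick comparison sketch (mine; it assumes the DGCNN get_graph_feature from the first section is in scope, since the PoinTr snippet reuses the same name, so in a single file one of them would need renaming; exact equality can also fail when several points are equidistant, because the two libraries may break ties in a different order):

x = torch.rand(2, 3, 1024).cuda()
f1 = get_graph_feature(x, k=20)  # DGCNN topk version
f2 = get_graph_feature_v3(x)     # knn_cuda version
print(torch.allclose(f1, f2))    # expect True for points in general position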
A quick test of the running time on Linux:
if __name__ == "__main__":
    import timeit

    x = torch.rand(2, 3, 1024).cuda()

    start = timeit.default_timer()
    feature1 = get_graph_feature(x)  # topk-based DGCNN version
    end = timeit.default_timer()
    print('Running time: %s Seconds' % (end - start))

    start = timeit.default_timer()
    feature2 = get_graph_feature_v3(x)  # knn_cuda version
    end = timeit.default_timer()
    print('Running time: %s Seconds' % (end - start))
Running time: 0.2865183869998873 Seconds
Running time: 0.0008672099993418669 Seconds

The gap is enormous.
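One caveat (my own note): CUDA kernels launch asynchronously and the first call pays one-time initialization costs, so a single cold-start timing like the above can overstate the gap. A steadier comparison warms up first, averages many calls, and synchronizes before reading the clock; a minimal sketch (the bench helper is mine, not from either repo):

import timeit
import torch

def bench(fn, x, n=100):
    fn(x)  # warm-up: the first call pays one-time CUDA initialization costs
    torch.cuda.synchronize()
    start = timeit.default_timer()
    for _ in range(n):
        fn(x)
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (timeit.default_timer() - start) / n

x = torch.rand(2, 3, 1024).cuda()
print('topk version:     %s s/call' % bench(get_graph_feature, x))
print('knn_cuda version: %s s/call' % bench(get_graph_feature_v3, x))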