IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> cuda_knn vs torch.topk() -> 正文阅读

[人工智能]cuda_knn vs torch.topk()

发现一个快速的KNN gpu算法
KNN_CUDA

DGCNN KNN

之前的knn算法用的都是DGCNN里面的torch.topk(),但是效率很低,网络训练既占用显存又慢;代替的pointnet2中的ball_query方法又不是严格的knn,可能会导致训练精度变低。

### DGCNN knn
def knn(x, k):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx, pairwise_distance
    
def get_graph_feature(x, k):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)
    device = torch.device('cuda')
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)
    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((neighbor-x, neighbor), dim=3).permute(0, 3, 1, 2)  # local and global all in
    return feature

KNN_CUDA

在PoinTr代码中发现,dgcnn_group.py里实现的edge_conv不一样

  1. 用了cuda_knn取代了自己写knn
  2. 加了with torch.no_grad()
from knn_cuda import KNN
knn = KNN(k=16, transpose_mode=False)
def get_graph_feature(coor_q, x_q, coor_k, x_k):
    # coor: bs, 3, np, x: bs, c, np
    k = 16
    batch_size = x_k.size(0)
    num_points_k = x_k.size(2)
    num_points_q = x_q.size(2)
    with torch.no_grad():
        _, idx = knn(coor_k, coor_q)  # bs k np
        assert idx.shape[1] == k
        idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x_k.size(1)
    x_k = x_k.transpose(2, 1).contiguous()
    feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
    feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
    x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
    feature = torch.cat((feature - x_q, x_q), dim=1)
    return feature

next work

准备用实际模型测一下这两种实现的区别,精度和速度上。
尽量在这周2022/3/27之前完成。
稍微修改了以下PoinTr中的代码,是他和dgcnn中的输入输出一致。

from knn_cuda import KNN
knn_cuda = KNN(k=20, transpose_mode=False)
def get_graph_feature_v3(x, k=20, idx=None):
    k = 20
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.contiguous()
    with torch.no_grad():
        _, idx = knn_cuda(x, x)  # bs k np
        assert idx.shape[1] == k
        idx = idx.transpose(1,2).contiguous()
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims) # (b,n,k,c)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() #(b,c,n,k)
    return feature

简单的在linxu下测试了运行时间

if __name__ == "__main__":
    import timeit
    x = torch.rand(2, 3, 1024).cuda()
    start=timeit.default_timer()
    feature1 = get_graph_feature(x)
    end=timeit.default_timer()
    print('Running time: %s Seconds'%(end-start))
    start=timeit.default_timer()
    feature2 = get_graph_feature_v3(x)
    end=timeit.default_timer()
    print('Running time: %s Seconds'%(end-start))

Running time: 0.2865183869998873 Seconds
Running time: 0.0008672099993418669 Seconds
差距极大

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-22 20:35:19  更:2022-03-22 20:40:07 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 1:30:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码