IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于pytorch简单实现稀疏3d卷积(SECOND) -> 正文阅读

[人工智能]基于pytorch简单实现稀疏3d卷积(SECOND)

卷积计算是深度学习模型的常见算子,在3D项目中,比如点云分割,由于点云数据是稀疏的,使用常规的卷积计算,将会加大卷积计算时间,不利于模型推理加速。由此SECOND网络提出了稀疏卷积的概念。

稀疏卷积的主要理念就是由正常的全部数据进行卷积运算,优化了为只计算有效的输入点的卷积结果。稀疏卷积的思路网上已经有很多简明扼要的文章,比如知乎的这一篇就很清晰,本文就是根据这一篇的思路实现的一个简单的稀疏卷积流程。建议先看一下先了解。

稀疏卷积的输入是有效输入点的索引坐标(哈希表)和对应的features值,大概流程是:

1,根据输入坐标得到输出点的索引坐标(哈希表)。每一个输入点,可以最多和kernel个点(比如3d卷积,kernel=3,则kernel点个数是3*3*3=27)相乘,得到kernel点个数的输出坐标。所以rulebook可以建立成kenel点个数的字典,每个kenel对应一个或多个输入点索引和输出点索引。

2,将输入点和对应kernel点进行矩阵乘法,得到卷积结果。

3,将同一个输出点坐标的卷积结果进行累加,根据输出点索引与真实坐标的关系,将结果还原到输出位置,即完成了稀疏卷积运算。

下面是实现的一个简单示例代码,其中稀疏卷积结果和普通卷积结果进行了对比,误差为0。

输入坐标和输出点坐标的映射关系,是遍历每个输出点的坐标,根据输出点坐标,kernel,stride可以得到相关的kernel点个输出点的坐标,如果在有点输出点列表里面,则表示这是一个有效输出点,更新输出点索引哈希表和rulebook字典。

这种方法的时间复杂度较大,需要遍历所有输出点,后面有优化方案,直接有公式计算输入点对应的输出点坐标。但是可以大概看一下整体流程。

# -*- coding: utf-8 -*-
import time

import torch
import torch.nn as nn
import itertools
import numpy as np

def generate_sparse_data(shape,
                         num_points,
                         num_channels,
                         integer=False,
                         data_range=(-1, 1),
                         with_dense=True,
                         dtype=np.float32):
    dense_shape = shape
    ndim = len(dense_shape)
    num_points = np.array(num_points)
    batch_size = len(num_points)
    batch_indices = []
    coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]),
                           axis=-1)
    coors_total = coors_total.reshape(-1, ndim)
    for i in range(batch_size):
        np.random.shuffle(coors_total)
        inds_total = coors_total[:num_points[i]]
        inds_total = np.pad(inds_total, ((0, 0), (0, 1)),
                            mode="constant",
                            constant_values=i)
        batch_indices.append(inds_total)
    if integer:
        sparse_data = np.random.randint(data_range[0],
                                        data_range[1],
                                        size=[num_points.sum(),
                                              num_channels]).astype(dtype)
    else:
        sparse_data = np.random.uniform(data_range[0],
                                        data_range[1],
                                        size=[num_points.sum(),
                                              num_channels]).astype(dtype)

    res = {
        "features": sparse_data.astype(dtype),
    }
    if with_dense:
        dense_data = np.zeros([batch_size, num_channels, *dense_shape],
                              dtype=sparse_data.dtype)
        start = 0
        for i, inds in enumerate(batch_indices):
            for j, ind in enumerate(inds):
                dense_slice = (i, slice(None), *ind[:-1])
                dense_data[dense_slice] = sparse_data[start + j]
            start += len(inds)
        res["features_dense"] = dense_data.astype(dtype)
    batch_indices = np.concatenate(batch_indices, axis=0)
    res["indices"] = batch_indices.astype(np.int32)
    return res



def get_Pin2Pout_Rulebook_3d(n,ho, wo,do,ks,stride, in_indice):
    '''
    根据有效的输入点位置,得到有效的输出点位置,并建立kernel, in_idx, out_indx字典关系。
    in_indice:有效点的坐标 [[hi,wi,ni],[hi1,wi1,ni1],...]

    return:
    offset, {k0:[[pin_idx, pout_idx],...], k2:[[pin_idx, pout_idx],...]}
    pout_indice, same to in_indice
    '''
    offset = {i: [] for i in range(ks**3)}
    pout_indice = []
    out_count = 0
    for b, i, j, d in itertools.product(range(n), range(ho), range(wo), range(do)):
        flag = False
        for kh, kw, kd in itertools.product(range(ks),range(ks),range(ks)):
            if [stride*i + kh, stride*j + kw,stride*d+kd,b] in in_indice:
                flag = True
                offset[kh*ks*ks+kw*ks+kd].append(
                    [in_indice.index([ stride*i + kh, stride*j + kw,stride*d+kd,b]), out_count])  # [in_index,out_index]
        if flag == True:
            pout_indice.append([b, i, j,d])
            out_count += 1
    return offset, pout_indice

def get_output_3d(rulebook,in_data,weight_data,out_indice,out_data):
    '''
    遍历每一个kernel, 通过查找pin_idx和对应的kernel, 矩阵乘得到pout的值,并放回位置。
    同一个pout结果累加
    '''
    for key in rulebook.keys():
        cur_book=rulebook[key]
        w_data=weight_data[key]
        for i in range(len(cur_book)):
            x=in_data[cur_book[i][0],:]
            n,ho,wo,do=out_indice[cur_book[i][1]]
            out_data[n,:,ho,wo,do]+=np.matmul(x,w_data)
    return out_data


def test_conv3d(sparse_dict,ci,co,kernel,stride):
    features=sparse_dict['features']
    features_dense=sparse_dict['features_dense']
    in_indices=sparse_dict['indices'] #

    conv3d=nn.Conv3d(ci,co, kernel,stride=stride, bias=False)
    weight = conv3d.weight.detach().numpy()  # co,ci,kh,kw
    weight = weight.reshape(co, ci, kernel ** 3).transpose(2, 1, 0)


    ref_out=conv3d(torch.tensor(features_dense))

    bs,co,ho,wo,do=ref_out.shape
    spconv_out=np.zeros([bs,co,ho,wo,do])
    rulebook,pout_indice=get_Pin2Pout_Rulebook_3d(bs,ho,wo,do,kernel,stride, in_indices.tolist())
    spconv_out=get_output_3d(rulebook,features,weight,pout_indice,spconv_out)

    dif=np.abs(ref_out.detach().numpy()-spconv_out)
    print('max diff is:',round(np.max(dif),4))
    print('sparse conv3d test over')
    return spconv_out

if __name__ =="__main__":
    shapes=(9,19,18)  # conv3d:(h,w,d)
    bs=1  #batch_size
    ks=3 #kernel_size
    stride=2
    ci=7
    co=32
    num_points = [100] * bs  # 100个有效点个数
    sparse_dict=generate_sparse_data(shapes,num_points,ci)
    test_conv3d(sparse_dict, ci, co, ks,stride)  

备注

该示例代码默认无padding,可以任意定义输入shapes, 其中generate_sparse_data是spconv的github代码里面给产生的稀疏数据代码。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-28 11:16:07  更:2021-11-28 11:17:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 2:15:03-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码