| |
|
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
| -> 人工智能 -> Pytorch中的grid_sample算子功能解析 -> 正文阅读 |
|
|
[人工智能]Pytorch中的grid_sample算子功能解析 |
|
???????? 近期在一个模型从pytorch迁移到mindspore框架中遇到一个算子适配问题,pytorch中的grid_sample在mindspore中没有对应的算子,需要考虑自定义实现。查找pytorch官网发现grid_sample是一种特殊的采样算法。 调用接口为: torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。 ???????? input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。 ???????? grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。 ???????? mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。 ???????? padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。 ???????? align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。 ???????? 要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。 注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。
?????????基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。
运行结果:
? ?????????从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。 |
|
|
|
|
| 上一篇文章 下一篇文章 查看所有文章 |
|
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
| 360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年12日历 | -2025/12/5 22:28:51- |
|
| 网站联系: qq:121756557 email:121756557@qq.com IT数码 |