IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> torch.argmax(input dim keepdim=False) -> 正文阅读

[人工智能]torch.argmax(input dim keepdim=False)

导读

最近有时间看一些目标检测项目的代码(基于Pytorch),里边很多Pytorch的相关操作都忘记了,特来此记录一下,用以加深记忆,而且还能以备一样处境的同学前来查询。今天的主角是torch.argmax(input, dim, keepdim=False)。

官方文档地址

https://pytorch.org/docs/stable/generated/torch.argmax.html

torch.argmax(input) → LongTensor

Returns the indices of the maximum value of all elements in the input tensor.

根据官方的解释,该函数可以返回输入张量中所有元素的最大值的索引。当然这只是最初级的用法,根据输入参数的不同,其返回的结果也不同。下面我们一起了解它的参数都有哪些作用。

参数解析

Parameters

  • input (Tensor) : the input tensor.
  • dim (int) :the dimension to reduce. If None, the argmax of the flattened input is returned.
  • keepdim (bool) : whether the output tensor has dim retained or not. Ignored if dim=None.

这是官网上对参数的解释,input就是我们输入的要操作的张量;dim是我们选择的要在张量的哪个维度上进行计算,输出这个维度最大值的索引,这里一行元素变成一个索引,所以官网中用了reduce;keepdim是询问输出是否与输入保持一样的形状,默认是不保持(False)。

举例演示

首先输入一个张量,注意我们输入的这个张量的shape为[5, 9]

>>> x = torch.randn(5,9)
>>> print(x)
tensor([[ 0.3918,  0.3978,  0.2819, -0.8487, -1.0499,  0.2124, -1.3527, -1.5335,
          1.1050],
        [ 0.8450, -0.3717, -0.4705, -0.4024,  2.1019, -0.8545,  1.9085,  0.5792,
         -0.4279],
        [ 0.1993, -0.2887,  0.4467,  0.4878,  1.4934, -1.3862,  0.3576, -0.2363,
         -2.0700],
        [ 0.0536,  0.9385,  1.2661, -0.3469, -0.5772, -0.7822,  0.8315, -1.7256,
         -0.4979],
        [ 1.1592, -0.1604,  0.2798,  0.5974,  0.1782, -2.3354, -1.7775, -0.8366,
          1.8993]])

接下来,现在第一个维度上进行操作

>>> torch.argmax(x,dim=0)
tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])

第一个维度是行,即按行计算,我们看到结果输出的维度为9,正好是输入张量x的列数。torch.argmax()的计算方式如下:
每次在所有行的相同位置取元素,然后计算取得元素集合的最大值索引。
第一次取所有行的第一位元素,x[:, 0], 得到

tensor([0.3918, 0.8450, 0.1993, 0.0536, 1.1592])

第二次取所有行的第二位元素,x[:, 1], 得到

tensor([0.3978, -0.3717, -0.2887, 0.9385, -0.1604])

依次类推,x有9列,我们也可以取9次,所有取的结果如下:

tensor([ 0.3918,  0.8450,  0.1993,  0.0536,  1.1592])
tensor([ 0.3978, -0.3717, -0.2887,  0.9385, -0.1604])
tensor([ 0.2819, -0.4705,  0.4467,  1.2661,  0.2798])
tensor([-0.8487, -0.4024,  0.4878, -0.3469,  0.5974])
tensor([-1.0499,  2.1019,  1.4934, -0.5772,  0.1782])
tensor([-1.3527,  1.9085,  0.3576,  0.8315, -1.7775])
tensor([-1.5335,  0.5792, -0.2363, -1.7256, -0.8366])
tensor([ 1.1050, -0.4279, -2.0700, -0.4979,  1.8993])

然后分别计算以上每个张量中元素的最大值的索引,便得到tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])

同理,按照列来操作也是一样的思路,这里就不详细说了,看结果:

>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])

经过上边例子的演示,我们可以知道torch.argmax(input,dim)可以返回input中dim维度上的最大值索引。
我们给x在目标检测中赋予具体的含义,假如x的形状为[num_bbox, anchor],那么x便是5个预测框分别与9个anchor计算得到的交并比,我们要选出来与预测框交并比最大的那个anchor,用来回归预测框越来越接近GT。这时候就要用到torch.argmax()找到与bbox交并比最大的anchor的序号。

>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])

即与第一个预测框交并比最大的是第9个anchor,与第二个预测框交并比最大的是第5个anchor…

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-10 22:31:02  更:2022-03-10 22:35:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 16:52:18-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码