IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> torch.max() 函数 -> 正文阅读

[人工智能]torch.max() 函数

torch.max() 函数

最近在玩图像目标分类问题,涉及到一个 torch.max() 函数,来记录一下

1、函数说明

output = torch.max(input, dim)

输入:

  • input:是softmax函数输出的一个tensor
  • dim:是max函数索引的维度0/10是每列的最大值,1`是每行的最大值

输出:

  • 函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

在多分类任务中我们并不需要知道各类别的预测概率,所以返回值的第一个tensor对分类任务没有帮助,而第二个tensor包含了预测最大概率的索引,所以在实际使用中我们仅获取第二个tensor即可。

2、函数举例

下面举个例子来理解这个函数的用法:

import torch

a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])

print(a.shape)

print(a)

输出:

torch.Size([3, 4])
tensor([[ 1,  5, 62, 54],
        [ 2,  6,  2,  6],
        [ 2, 65,  2,  6]])

每行的最大值:(每列同理)

values, indexes = torch.max(a, 1) # 按照每一行取最大值
print(values)
print(indexes)

输出:

tensor([62,  6, 65]) # 每一行的最大值
tensor([2, 3, 1]) # 最大值对应的下标(下标从0开始,如果有相同的元素,取最后一个)

3、分类准确率计算问题

计算准确率时,我们需要看 下标index 与 label标签 是否相等,所以我们只需要 torch.max() 返回的 indexes 即可,即 torch.max(a, 1)[1]

我们已知 predict 的tensor,label 的tensor,将其转换为 numpy 数组

predict = torch.tensor([[1, 5, 62, 54], [2, 6, 2, 6], [2, 65, 2, 6]])
label = torch.tensor([[1],[3],[2]])

values, indexes = torch.max(predict, 1)
print(values)
print(indexes)


values, indexes = torch.max(label, 1)
print(values)
print(indexes)

pred_y = torch.max(predict, 1)[1].numpy() #  softmax函数输出
label_y = torch.max(label, 1)[1].numpy() #  样本标签
accuracy = (pred_y == label_y).sum() / len(label_y) # 准确率
print(accuracy)

输出:

tensor([62,  6, 65])
tensor([2, 3, 1])

tensor([1, 3, 2])
tensor([0, 0, 0])

0.0
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-10 22:31:02  更:2022-03-10 22:32:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 17:38:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码