torch.max() 函数
最近在玩图像目标分类问题,涉及到一个 torch.max() 函数,来记录一下
1、函数说明
output = torch.max(input, dim)
输入:
- input:
是softmax函数输出的一个 tensor - dim:
是max函数索引的维度 0/1, 0是每列的最大值, 1`是每行的最大值
输出:
- 函数会返回两个
tensor ,第一个tensor 是每行的最大值;第二个tensor 是每行最大值的索引。
在多分类任务中我们并不需要知道各类别的预测概率,所以返回值的第一个tensor 对分类任务没有帮助,而第二个tensor 包含了预测最大概率的索引,所以在实际使用中我们仅获取第二个tensor 即可。
2、函数举例
下面举个例子来理解这个函数的用法:
import torch
a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
print(a.shape)
print(a)
输出:
torch.Size([3, 4])
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
每行的最大值:(每列同理)
values, indexes = torch.max(a, 1)
print(values)
print(indexes)
输出:
tensor([62, 6, 65])
tensor([2, 3, 1])
3、分类准确率计算问题
计算准确率时,我们需要看 下标index 与 label标签 是否相等,所以我们只需要 torch.max() 返回的 indexes 即可,即 torch.max(a, 1)[1]
我们已知 predict 的tensor,label 的tensor,将其转换为 numpy 数组
predict = torch.tensor([[1, 5, 62, 54], [2, 6, 2, 6], [2, 65, 2, 6]])
label = torch.tensor([[1],[3],[2]])
values, indexes = torch.max(predict, 1)
print(values)
print(indexes)
values, indexes = torch.max(label, 1)
print(values)
print(indexes)
pred_y = torch.max(predict, 1)[1].numpy()
label_y = torch.max(label, 1)[1].numpy()
accuracy = (pred_y == label_y).sum() / len(label_y)
print(accuracy)
输出:
tensor([62, 6, 65])
tensor([2, 3, 1])
tensor([1, 3, 2])
tensor([0, 0, 0])
0.0
|