PyTorch学习笔记:torch.max——数组的最大值
torch.max()——数组的最大值
torch.max()有两种形式
形式Ⅰ
torch.max(input) → Tensor
功能:输出数组的最大值
注意:
- 只有一个输入,只需要输入一个数组
- 该方式也可以通过
a.max() 实现,后者是求数组a 的最大值
形式Ⅱ
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
功能:按指定维度判断,返回数组的最大值以及最大值处的索引
输入:
input :待判定的数组dim :给定的维度keepdim :如果指定为True ,则输出的张量数组维数和输入一致,并且除了dim 维度是1,其余的维度大小和输入数组维度大小一致。如果改为False ,则相当于将True 的结果压缩了(删去了大小是1的dim 维度)。两者的差别就在于是否保留dim 维度。
注意:
- 如果在指定的维度中,有多个重复的最大值,则返回第一个最大值的索引
- 该函数返回由最大值以及最大值处的索引组成元组
(max,max_indices) - 该函数也可以通过
a.max() 实现,后者是求数组a 的最大值,只需要再指明dim 以及keepdim 即可 - 输出的索引是对应的
dim 维度上的索引,注意含义。
上述两种函数形式本质区别就是有没有指出dim ,如果未指出dim ,则返回整个数组的最大值,不返回索引。如果指出了dim ,则在指定的维度上搜索最大值,返回最大值以及索引。
代码案例
一般用法
import torch
a=torch.arange(10).reshape(2,5)
b=torch.max(a)
c=torch.max(a,0)
print(a)
print(b)
print(c)
输出
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
tensor(9)
torch.return_types.max(
values=tensor([5, 6, 7, 8, 9]),
indices=tensor([1, 1, 1, 1, 1]))
keepdim 定为True 或者Flase 的区别
import torch
a=torch.arange(10).reshape(2,5)
b=torch.max(a,0,True)
c=torch.max(a,0)
print(a.shape)
print(b[0].shape)
print(c[0].shape)
输出,只有维度不同
torch.Size([2, 5])
torch.Size([1, 5])
torch.Size([5])
不同的dim 对结果的影响,这里以input 的维数是3为例,维数更多的可以类推,首先定义好数组input 。
import torch
a=torch.arange(32).reshape(2,4,4)
print(a)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]])
dim 为0
b=torch.max(a,0)
print(b)
输出:
torch.return_types.max(
values=tensor([[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]),
indices=tensor([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]))
dim 为1
c=torch.max(a,1)
print(c)
输出
torch.return_types.max(
values=tensor([[12, 13, 14, 15],
[28, 29, 30, 31]]),
indices=tensor([[3, 3, 3, 3],
[3, 3, 3, 3]]))
dim 为2
d=torch.max(a,2)
print(d)
输出
torch.return_types.max(
values=tensor([[ 3, 7, 11, 15],
[19, 23, 27, 31]]),
indices=tensor([[3, 3, 3, 3],
[3, 3, 3, 3]]))
扩展
在分类任务中经常用到该函数,softmax 函数输出得到的概率再经过torch.max 函数得到最终的预测结果(预测结果一般和索引值一一对应,因此可以用索引值来表示预测结果),可以进一步与标签做比较,得到准确率。
import numpy as np
import torch
a = torch.tensor(np.random.rand(16, 20))
pre = torch.max(a,dim=1)[1]
print(pre)
print(pre.shape)
输出,共有16个结果(对应batch_size=16)
tensor([11, 10, 1, 9, 19, 19, 15, 14, 5, 2, 17, 14, 13, 15, 15, 17])
torch.Size([16])
官方文档
torch.max():https://pytorch.org/docs/stable/generated/torch.max.html?highlight=torch%20max#torch.max
|