Definition of dim
dim refers to a dimension (axis) of a tensor.
import torch
x = torch.randn(2, 3, 3)
print(x)
print(x.size())
print(x.dim())
Output:
tensor([[[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]],
[[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]]])
torch.Size([2, 3, 3])
3
This is hard to read as printed, but it becomes clearer once the brackets are laid out one level per line:
[
[
[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]
],
[
[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]
]
]
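- With this layout the size (2, 3, 3) is easy to read off: it goes from the outermost brackets to the innermost.
- x.dim() = 3 means that x has three dimensions, indexed dim = 0, 1, 2:
0 corresponds to the first entry of x.size(), the 2 in (2, 3, 3)
1 corresponds to the second entry, the middle 3 in (2, 3, 3)
2 corresponds to the third entry, the last 3 in (2, 3, 3)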
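The correspondence between dim indices and the entries of x.size() can be checked directly. The following is a minimal sketch; it uses the shape (2, 3, 4) rather than (2, 3, 3), an assumption made only so that each dimension has a distinct size.

```python
import torch

# Shape (2, 3, 4) is chosen for illustration so every dim has a different size.
x = torch.randn(2, 3, 4)

# x.dim() is the number of dimensions; valid dim indices are 0 .. x.dim() - 1.
sizes = [x.size(d) for d in range(x.dim())]
for d, n in enumerate(sizes):
    print(f"dim {d} -> size {n}")

# Negative indices count from the innermost dimension, so dim=-1 is dim=2 here.
last = x.size(-1)
print(last)
```

Reading the printed lines top to bottom follows the brackets from the outside in.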
Understanding dim
When dim = 0, we are looking at the slices of x along dimension 0, each of size (3, 3), that is:
x = torch.randn(2, 3, 3)
print(x)
for i in x:
    print(i)
    print(i.size())
Output:
tensor([[[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]],
[[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]]])
tensor(
[
[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]
]
)
torch.Size([3, 3])
tensor(
[
[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]
]
)
torch.Size([3, 3])
So working along dim = 0 amounts to removing dimension 0 from x: what remains is a set of slices of size (3, 3).
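The same idea can be tried for the other dimensions. The sketch below is illustrative rather than part of the original walk-through: it uses torch.unbind to take the slices along each dim, again with the assumed shape (2, 3, 4) so the effect is visible in the sizes.

```python
import torch

# Shape (2, 3, 4) is an assumption for illustration (distinct size per dim).
x = torch.randn(2, 3, 4)

# Iterating over a tensor walks along dim 0, like the for loop above.
shapes_from_loop = [tuple(s.shape) for s in x]
print(shapes_from_loop)

# torch.unbind(x, dim=d) returns the slices along dimension d as a tuple.
# Each slice has the shape of x with dimension d removed.
for d in range(x.dim()):
    slices = torch.unbind(x, dim=d)
    print(d, len(slices), tuple(slices[0].shape))
```

For dim 0 there are 2 slices of size (3, 4), for dim 1 there are 3 slices of size (2, 4), and for dim 2 there are 4 slices of size (2, 3).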
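Verification
- torch.argmax(tensor)
Returns the index of the largest value in the tensor. When a dim is given, the comparison is made between slices of the same shape.
Example:
>>> x = torch.tensor([1, 5, 8, 4, 6])
>>> torch.argmax(x)
tensor(2)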
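One detail worth checking before moving on: when torch.argmax is called without dim on a tensor with more than one dimension, the returned index refers to the flattened tensor. The small sketch below uses made-up values to show this and converts the flat index back to a (row, column) pair with plain integer arithmetic.

```python
import torch

# Made-up values for illustration; the largest one (9) sits at row 1, column 1.
x = torch.tensor([[1, 5, 8],
                  [4, 9, 6]])

# Without dim, the index counts positions in the flattened tensor [1, 5, 8, 4, 9, 6].
flat_index = torch.argmax(x)
print(flat_index)

# Recover the 2-D position from the flat index using the number of columns.
row, col = divmod(flat_index.item(), x.size(1))
print(row, col, x[row, col].item())
```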
import torch
x = torch.randn(2, 3, 3)
print(x)
print('='*50, end='\n\n')
for i in x:
    print(i)
    print(i.size())
print('='*50, end='\n\n')
print(x.size())
print(x.dim())
print('='*50, end='\n\n')
y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
[
[
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
],
[
[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]
]
]
)
==================================================
tensor([[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]])
torch.Size([3, 3])
tensor([[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================
torch.Size([2, 3, 3])
3
==================================================
tensor([[1, 0, 0],
[0, 1, 1],
[0, 1, 0]])
torch.Size([3, 3])
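Let's analyze y[0] = [1, 0, 0]. Why this result? There are two hypotheses:
- It compares [-1.3918, 0.0620, -0.4111] with [ 0.7135, -0.5290, -0.7656], position by position:
[-1.3918, 0.7135]: 0.7135 is larger, so it returns 1
[0.0620, -0.5290]: 0.0620 is larger, so it returns 0
[-0.4111, -0.7656]: -0.4111 is larger, so it returns 0
- It compares the columns inside each x[i], which would give a 2x3 output. For example, for x[0]:
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
comparing each column with torch.argmax gives [1, 0, 2]
- If we follow the idea of removing dim = 0, we get x':
[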
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
],
[
[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]
]
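That is, two tensors of size (3, 3). This makes it clear why the second hypothesis does not hold: the comparison is between the two tensors, whereas the second hypothesis compares values inside each tensor separately and then merges the two results.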
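The two hypotheses can be put side by side using the values printed above. This is a sketch for checking the reasoning; the variable names between and within are chosen here for illustration.

```python
import torch

# The values from the printed output above.
x = torch.tensor([
    [[-1.3918,  0.0620, -0.4111],
     [ 1.9623, -1.3399, -0.4673],
     [-0.0185, -1.9024,  0.1340]],
    [[ 0.7135, -0.5290, -0.7656],
     [ 0.2642,  0.5956, -0.0718],
     [-0.7465, -0.8098, -0.0874]],
])

# Hypothesis 1: compare x[0] with x[1] position by position.
between = torch.argmax(x, dim=0)
print(between)

# Hypothesis 2: compare the rows inside each x[i], column by column,
# then stack the per-tensor results into a 2x3 output.
within = torch.stack([torch.argmax(x[i], dim=0) for i in range(x.size(0))])
print(within)
```

The first result matches the printed y, so dim=0 behaves like hypothesis 1. The second result has size (2, 3) and is what torch.argmax(x, dim=1) returns.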
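- Summary: the comparison is made between the slices taken along the specified dimension, position by position. Here (2, 3, 3) -> (3, 3): the size of the result is the size of the input with the specified dim removed.
- With only two dimensions this may be easier to see:
import torch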
x = torch.randn(2,3)
print(x)
y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
[
[ 0.0251, -0.3640, 0.1965],
[ 0.6902, 0.9846, 0.2035]
]
)
tensor([1, 1, 1])
torch.Size([3])
Removing dim = 0, the comparison is between [ 0.0251, -0.3640, 0.1965] and [ 0.6902, 0.9846, 0.2035], and the size goes from (2, 3) to (3).
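To make the mechanism explicit, the two-dimensional case can be written out by hand without PyTorch. The function argmax_dim0 below is a hypothetical helper written for this note, not a library function; it is a sketch that assumes a non-empty list of equal-length rows.

```python
def argmax_dim0(rows):
    """For each column, return the index of the row holding the largest value.

    Hypothetical helper for illustration; assumes equal-length, non-empty rows.
    """
    n_cols = len(rows[0])
    result = []
    for col in range(n_cols):
        # Collect the values that share this column position across all rows.
        column = [row[col] for row in rows]
        # Index of the largest value in that column.
        result.append(max(range(len(column)), key=column.__getitem__))
    return result


# The values from the two-dimensional example above.
x = [[0.0251, -0.3640, 0.1965],
     [0.6902,  0.9846, 0.2035]]
print(argmax_dim0(x))
```

The input has 2 rows of 3 values and the result has 3 entries, one per column, which mirrors (2, 3) -> (3).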
Now look back at the three-dimensional example above:
[
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
],
[
[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]
]
For the first row, comparing the two is equivalent to running torch.argmax() with dim=0 on the tensor below:
[
[-1.3918, 0.0620, -0.4111],
[ 0.7135, -0.5290, -0.7656]
]
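This row-by-row reading can be checked for every row, not just the first. The sketch below uses a random tensor with an arbitrary seed (an assumption, only for repeatability) and also confirms that the indices in y select the largest value at each position.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the run is repeatable
x = torch.randn(2, 3, 3)
y = torch.argmax(x, dim=0)

# Row j of y should equal the argmax over dim 0 of the 2x3 tensor formed by
# stacking row j of x[0] on top of row j of x[1].
rows_match = all(
    torch.equal(y[j], torch.argmax(torch.stack([x[0][j], x[1][j]]), dim=0))
    for j in range(x.size(1))
)
print(rows_match)

# Element-wise view of the same rule: y[j, k] is the i that maximizes x[i, j, k],
# so gathering along dim 0 with y recovers the maximum values.
picked = torch.gather(x, 0, y.unsqueeze(0)).squeeze(0)
print(torch.equal(picked, x.max(dim=0).values))
```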