学习代码的第一个拦路虎,在pytorch中非常常用的维度转换。
首先了解一下tensor的size是怎么来的,几个中括号就说明有几个维度,然后看第一个中括号里用逗号分隔开了几个元素,就是第一个维度的值,其他依次类推。例如Tensor([[[1,2,3],[4,5,6]]])中第一个中括号里为[[1,2,3],[4,5,6]],只有一个元素,第二个中括号内为[1,2,3],[4,5,6],有两个元素,第三个为[1,2,3],有三个元素。
view()
view变换维度,把原先tensor中的数据按行优先的顺序排成一个一维数据(这里应该是因为要求地址是连续存储的),然后按照输入参数要求,组合成其他维度的tensor。例如:
a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(3,2))
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
permute()
permute将tensor中任意维度调换。permute里的参数对应的是张量a的维度索引,permute的输入参数的维度必须与a一致,并且只能是0,1,2…,dim,能够一一对应地索引到a里面的维度。这个比view()稍微难理解一点,但只要明白每个数值对应维度的索引值,改变索引值最终改变输出值。a.permute(2,0,1)的意思是,把a的最后一个维度放到最前面。详细理解图见下:
b=torch.Tensor([[[1,2,3],[4,5,6]]])
print(b.size())
permuted=b.permute(2,0,1)
print(permuted.size())
print(permuted)
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 4]],
[[2, 5]],
[[3, 6]]], dtype=torch.int32)
transpose()
对于二维tensor,permute(1,0)做的就是转置,等价于transpose()。
transpose()选择tensor中两个维度进行转置。代码中transpose(1,2)表示将三维张量中的后两维转置。
c=torch.Tensor([[[1,2,3],[4,5,6]]])
print(c.transpose(1,2))
tensor([[[1., 4.],
[2., 5.],
[3., 6.]]])
|