1. 改变shape
torch.reshape()、torch.view()可以调整Tensor的shape,返回一个新shape的Tensor,torch.view()是老版本的实现,torch.reshape()是最新的实现,两者都是用来重塑tensor的shape的。view只适合对满足连续性条件(contiguous)的tensor进行操作,而reshape同时还可以对不满足连续性条件的tensor进行操作,具有更好的鲁棒性。view能干的reshape都能干,如果view不能干就可以用reshape来处理。
示例代码:
import torch
a = torch.rand(4, 1, 28, 28)
print(a.shape)
print(a.view(4 * 1, 28, 28).shape)
print(a.reshape(4 * 1, 28, 28).shape)
print(a.reshape(4, 1 * 28 * 28).shape)
输出结果:
torch.Size([4, 1, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 784])
注意:维度变换的时候要注意实际意义。
2. 增加维度
torch.unsqueeze(index)可以为Tensor增加一个维度,增加的这一个维度的位置由我们自己定义,新增加的这一个维度不会改变数据本身,只是为数据新增加了一个组别,这个组别是什么由我们自己定义。
比如定义了一个Tensor:
a = torch.randn(4, 1, 28, 28)
这个Tensor有4个维度,我们可以在现有维度的基础上插入一个新的维度,插入维度的index在[-a.dim()-1, a.dim()+1]范围内,并且当index>=0,则在index前面插入这个新增加的维度;当index < 0,则在index后面插入这个新增的维度。
示例代码:
print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(4).shape)
print(a.unsqueeze(-4).shape)
print(a.unsqueeze(-5).shape)
print(a.unsqueeze(5).shape)
输出结果:
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 28, 1, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
Traceback (most recent call last):
File "/home/lhy/workspace/mmdetection/my_code/pytorch_ws/tensotr_0.py", line 218, in <module>
print(a.unsqueeze(5).shape)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
在执行a.unsqueeze(5)时报错,是因为超出了index的范围。
未完待续。。。。
参考:
PyTorch:view() 与 reshape() 区别详解_Flag_ing的博客-CSDN博客
Pytorch Tensor维度变换_洪流之源-CSDN博客
|