Pytorch中的 flatten,squeeze 和 unsqueeze 的区分
解释
flatten() 用于将数据展开。
squeeze() 用于将数据进行压缩,移除某个维度
- Compute torch.squeeze(input). It squeezes (removes) the size 1 and returns a tensor with all other dimensions of the input tensor.
unsqueeze() 用于将数据解压缩,扩充一个维度
- Compute torch.unsqueeze(input, dim). It inserts a new dimension of size 1 at the given dim and returns the tensor.
举例:
import torch
T = torch.ones(2,1,2)
print("Original Tensor T:\n", T )
print("Size of T:", T.size())
原数据 T 的输出:
原数据 T 的 flatten() 输出
T.flatten()
torch.flatten()
原数据 T 的 squeeze() 输出
T.squeeze(0)
troch.squeeze(T)
注意观察,两者的维度,(中括号个数)
原数据 T 的 unsqueeze() 输出
T.unsqueeze(dim=0)
torch.unsqueeze(T, dim=0)
参考链接
https://www.tutorialspoint.com/how-to-squeeze-and-unsqueeze-a-tensor-in-pytorch
https://pytorch.org/docs/stable/generated/torch.squeeze.html
|