定义
- torch.squeeze(input,dim)
- torch.unsqueeze(input,dim)
用法
- torch.squeeze():对输入数据的维度dim进行压缩,默认删除所有为1的维度
- torch.unsqueeze():对输入数据的维度dim进行扩充,dim为必填参数
示例
torch.squeeze()
import torch
a = torch.tensor([[2],[3],[1],[0],[5],[6]])
print(torch.squeeze(a))
print(torch.squeeze(a,0))
print(torch.squeeze(a,1))
>>>tensor([2, 3, 1, 0, 5, 6])
tensor([[2],
[3],
[1],
[0],
[5],
[6]])
tensor([2, 3, 1, 0, 5, 6])
torch.unsqueeze()
import torch
a = torch.tensor([[2],[3],[1],[0],[5],[6]])
print(torch.unsqueeze(a,0))
print(torch.unsqueeze(a,1))
>>>tensor([[[2],
[3],
[1],
[0],
[5],
[6]]])
tensor([[[2]],
[[3]],
[[1]],
[[0]],
[[5]],
[[6]]])
?
|