view reshape 重塑
import torch
# view / reshape: reinterpret the same 4*1*28*28 = 3136 elements under a new
# shape. Every target shape below keeps that element count unchanged.
a = torch.rand(4, 1, 28, 28)
print("a.shape:", a.shape)
for label, reshaped in (
    ("a.view(4, 28*28):", a.view(4, 28 * 28)),
    ("a.view(4*28, 28):", a.view(4 * 28, 28)),
    ("a.view(4, 28, 28):", a.view(4, 28, 28)),
    ("a.reshape(4, 28, 28):", a.reshape(4, 28, 28)),
):
    print(label, reshaped.shape)
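- 注意:view 之后各维度的乘积必须等于原张量的元素总数(此处为 4*1*28*28 = 3136),否则会报错。
- view 和 reshape 在多数情况下可以互换;区别在于 view 要求张量的内存布局与新形状兼容(例如连续存储 contiguous),否则报错,而 reshape 在必要时会自动复制数据。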
unsqueeze 插入维度:在指定位置增加一个大小为 1 的维度(对 4 维张量,合法位置范围为 [-5, 4])
import torch
# unsqueeze(d) inserts a new size-1 dimension at position d of the result.
# For this 4-D tensor the result is 5-D, so negative positions are offset by
# 5: -1 is the same as 4, -4 the same as 1, and -5 the same as 0.
a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)
for position in (0, -1, 4, -4, -5):
    print(f"a.unsqueeze({position}).shape", a.unsqueeze(position).shape)
squeeze 挤压:删除大小为 1 的维度
import torch
# squeeze() with no argument drops every size-1 dimension.
# squeeze(d) drops dimension d only when its size is 1. Dimension 1 here has
# size 32, so squeeze(1) leaves the shape unchanged.
a = torch.rand(1, 32, 1, 1)
print("a.squeeze().shape:", a.squeeze().shape)
for axis in (0, -1, 1, -4):
    print(f"a.squeeze({axis}).shape:", a.squeeze(axis).shape)
expand 扩展1:广播
不会复制数据,也不增加内存:返回的是一个视图,被扩展维度的步长(stride)为 0,多个位置共享同一份数据。只能扩展大小为 1 的维度;参数 -1 表示保持该维度原来的大小。
import torch
# expand broadcasts size-1 dimensions to a larger size without copying data.
# A size of -1 means "keep this dimension's current size".
# -1 is the ONLY legal negative size. The earlier version of this demo passed
# -4, which is not a valid size: depending on the PyTorch version it yields a
# nonsensical negative shape or raises RuntimeError. The last (singleton)
# dimension is therefore expanded to 4 instead.
b = torch.rand(1, 3, 1, 1)
print("b.expand([2, 3, 2, 2]).shape: ",
      b.expand([2, 3, 2, 2]).shape)
print("b.expand([-1, -1, -1, -1]).shape: ",
      b.expand([-1, -1, -1, -1]).shape)
print("b.expand([-1, 3, -1, 4]).shape: ",
      b.expand([-1, 3, -1, 4]).shape)
repeat 扩展2:复制
会真正复制数据,增加内存占用。注意参数是每个维度的重复次数,而不是目标大小:例如形状 (1, 3, 1, 1) 经 repeat(2, 3, 2, 2) 后得到 (2, 9, 2, 2)。
import torch
# repeat tiles the tensor. Each argument is a per-dimension repeat count (not
# a target size), so dim 1 of size 3 repeated 3 times becomes 9.
b = torch.rand(1, 3, 1, 1)
for reps in ((2, 3, 2, 2), (2, 1, 2, 1)):
    # str() of a tuple gives "(2, 3, 2, 2)", reproducing the original label.
    print(f"b.repeat{reps}.shape:", b.repeat(*reps).shape)
.t 转置(仅适用于维度不超过 2 的张量;更高维请使用 transpose 或 permute)
import torch
# .t() swaps the two axes of a 2-D tensor: the 2x3 matrix becomes 3x2.
rows = [[1, 2, 3], [4, 5, 6]]
b = torch.tensor(rows)
transposed = b.t()
print("b:\n", b)
print("b.t():\n", transposed)
transpose 维度交换
import torch
# Round-trip a transposed tensor through a flat view, once wrongly and once
# correctly.
a = torch.rand(4, 3, 28, 14)
swapped = a.transpose(1, 3)  # dims 1 and 3 exchanged -> (4, 14, 28, 3)
print("a.transpose(1,3).shape", swapped.shape)

# contiguous() materialises the transposed element order so view() can flatten it.
flat = swapped.contiguous().view(4, 3 * 28 * 14)

# Intentionally WRONG: the flat data is laid out in (4, 14, 28, 3) order, so
# reading it back directly as (4, 3, 28, 14) scrambles the elements.
a1 = flat.view(4, 3, 28, 14)
# Correct: restore the (4, 14, 28, 3) layout, then swap dims 1 and 3 back.
a2 = flat.view(4, 14, 28, 3).transpose(1, 3)

print("a == a1? ", torch.all(torch.eq(a, a1)))
print("a == a2? ", torch.all(torch.eq(a, a2)))
permute 维度重排
import torch
# permute reorders all dimensions at once: result dim i is input dim order[i].
a = torch.rand(4, 3, 28, 14)
for order in ((3, 2, 1, 0), (1, 3, 0, 2)):
    # str() of a tuple gives "(3, 2, 1, 0)", reproducing the original label.
    print(f"a.permute{order}: ", a.permute(*order).shape)
- 注意:permute(以及 transpose)只改变维度顺序和步长(stride),不会移动底层数据,因此结果通常不再是连续存储的;之后若要调用 view,需要先用 .contiguous() 得到一份按新维度顺序连续存储的副本。
|