1、shape与size()
print('*' * 100)
print('查看数据形状')
data1 = torch.randn((7,3,32,32))
print("shape:",data1.shape)
print("size:",data1.size())
查看数据形状 shape: torch.Size([7, 3, 32, 32]) size: torch.Size([7, 3, 32, 32])
2、squeeze与unsqueeze
print('*' * 100)
print('压缩')
data2 = torch.randn((3,1,32,1,32))
print("data2_shape:",data2.size())
data2_y = torch.squeeze(data2)
print("squeeze:",data2_y.size())
data2_y2 = torch.squeeze(data2,dim=0)
print("squeeze_0:",data2_y2.size())
data2_y3 = torch.squeeze(data2, dim=1)
print("squeeze_1:",data2_y3.size())
data2_y4 = torch.squeeze(data2_y3, dim=2)
print("squeeze_2:",data2_y4.size())
print('*' * 100)
print("增维")
data3 = torch.randn((3,32,32))
print("data3_shape",data3.shape)
data3_1 = torch.unsqueeze(data3,dim=0)
print("unsqueeze_0",data3_1.shape)
data3_2 = torch.unsqueeze(data3_1,dim=-1)
print("unsqueeze_-1",data3_2.shape)
压缩 data2_shape: torch.Size([3, 1, 32, 1, 32]) squeeze: torch.Size([3, 32, 32]) squeeze_0: torch.Size([3, 1, 32, 1, 32]) squeeze_1: torch.Size([3, 32, 1, 32]) squeeze_2: torch.Size([3, 32, 32])
增维 data3_shape torch.Size([3, 32, 32]) unsqueeze_0 torch.Size([1, 3, 32, 32]) unsqueeze_-1 torch.Size([1, 3, 32, 32, 1])
Process finished with exit code 0
3、广播机制
print('*'*100)
print("广播机制")
data4 = torch.randn((8,8))
print("data4_shape:",data4.size())
data5 = torch.randn(1)
print("data5_shape",data5.size())
data6 = data4 + data5
print("(data4+data5)_shape",data6.size())
广播机制 data4_shape: torch.Size([8, 8]) data5_shape torch.Size([1]) (data4+data5)_shape torch.Size([8, 8])
4、view与reshape
print('*'*100)
print("改变形状")
x = torch.randn(4, 4)
print('x_shape:',x.size())
y = x.view(16)
print("y_shape:",y.size())
z = x.view(-1,8)
print("z_shape:",z.size())
a = torch.randn(1, 2, 3, 4)
print("a_shape:",a.size())
b = a.transpose(1,2)
print("b_shape:",b.size())
c = a.view(1, 3, 2, 4)
print("c_shape:",c.size())
print("a_b:",torch.equal(a,b))
print("b_c:",torch.equal(b,c))
print("a_c",torch.equal(a,c))
data7 = torch.arange(12).reshape(3,4)
data8 = torch.arange(12).view(3,4)
print("reshape:",data7)
print("view:",data8)
改变形状 x_shape: torch.Size([4, 4]) y_shape: torch.Size([16]) z_shape: torch.Size([2, 8]) a_shape: torch.Size([1, 2, 3, 4]) b_shape: torch.Size([1, 3, 2, 4]) c_shape: torch.Size([1, 3, 2, 4]) a_b: False b_c: False a_c False
reshape: tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) view: tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]])
5、permute
a2 = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("a2_shape:",a2.shape)
b2 = a2.view(2,3,2)
print("b2_shape:",b2.shape)
print("b2:",b2)
a3 = torch.tensor([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])
print("a3_shape:",a3.shape)
b3 = a3.permute(0,2,1)
print("b3_shape:",b3.shape)
print("b3:",b3)
a4 = torch.tensor([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])
b4 = a4.transpose(1,2)
print(b4)
a2_shape: torch.Size([2, 2, 3]) b2_shape: torch.Size([2, 3, 2]) b2: tensor([[[ 1, 2], [ 3, 4], [ 5, 6]], [[ 7, 8], [ 9, 10], [11, 12]]]) a3_shape: torch.Size([2, 2, 3]) b3_shape: torch.Size([2, 3, 2]) b3: tensor([[[ 1, 4], [ 2, 5], [ 3, 6]], [[ 7, 10], [ 8, 11], [ 9, 12]]]) tensor([[[ 1, 4], [ 2, 5], [ 3, 6]], [[ 7, 10], [ 8, 11], [ 9, 12]]])
6、torch的连续性
a5 = torch.arange(12).reshape(3,4)
print("a5:",a5)
print("查看a5内存存储布局",[i for i in a5.storage()])
a5_y = a5.permute(1,0)
print("permute:",a5_y)
print("查看a5_permute内存存储布局", [i for i in a5_y.storage()])
a5_z =a5_y.is_contiguous()
print("permute()存储是否连续:",a5_z)
a5_f =a5_y.contiguous()
a5_g = a5_f.is_contiguous()
print("contiguous()存储是否连续:",a5_g)
print("查看a5_permute_contiguous内存存储布局", [i for i in a5_f.storage()])
a5_h = a5_f.view(3,4)
print("view:",a5_h)
a5_l = a5_h.is_contiguous()
print("view()存储是否连续:",a5_l)
print("查看a5_permute_contiguous_view内存存储布局", [i for i in a5_h.storage()])
a5: tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) 查看a5内存存储布局 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] permute: tensor([[ 0, 4, 8], [ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11]]) 查看a5_permute内存存储布局 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] permute()存储是否连续: False contiguous()存储是否连续: True 查看a5_permute_contiguous内存存储布局 [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] view: tensor([[ 0, 4, 8, 1], [ 5, 9, 2, 6], [10, 3, 7, 11]]) view()存储是否连续: True 查看a5_permute_contiguous_view内存存储布局 [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
总结:
1、reshape 相当于 先contiguous整理内存后 + view
2、permute = transpose().transpose()… 多次.transpose
3、一般需要将tensor展开时使用view(),而在需要转置时使用permute()。
4、在使用permute后,要继续使用view()的话,需要先用contiguous整理内存后,才能使用view,或者直接使用reshape()。
5、view是对原始数据进行操作,contiguous后,如果原来是行存储则存储区不变,如果是列存储,就会开辟了一个新的存储区。
|