主要针对不知道函数结果的操作进行测试
1、函数测试
1.1、测试 .sum(dim=(m,n))
f = torch.arange(4 * 5 * 6).view(1,1, 4, 5, 6)
f.shape
Out[19]: torch.Size([1, 1, 4, 5, 6])
f
Out[20]:
tensor([[[[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[ 12, 13, 14, 15, 16, 17],
[ 18, 19, 20, 21, 22, 23],
[ 24, 25, 26, 27, 28, 29]],
[[ 30, 31, 32, 33, 34, 35],
[ 36, 37, 38, 39, 40, 41],
[ 42, 43, 44, 45, 46, 47],
[ 48, 49, 50, 51, 52, 53],
[ 54, 55, 56, 57, 58, 59]],
[[ 60, 61, 62, 63, 64, 65],
[ 66, 67, 68, 69, 70, 71],
[ 72, 73, 74, 75, 76, 77],
[ 78, 79, 80, 81, 82, 83],
[ 84, 85, 86, 87, 88, 89]],
[[ 90, 91, 92, 93, 94, 95],
[ 96, 97, 98, 99, 100, 101],
[102, 103, 104, 105, 106, 107],
[108, 109, 110, 111, 112, 113],
[114, 115, 116, 117, 118, 119]]]]])
g = torch.arange(6).view(1,1, 1, 1, 6)
g.shape
Out[22]: torch.Size([1, 1, 1, 1, 6])
g
Out[23]: tensor([[[[[0, 1, 2, 3, 4, 5]]]]])
k = f*g
k.shape
Out[25]: torch.Size([1, 1, 4, 5, 6])
k
Out[26]:
tensor([[[[[ 0, 1, 4, 9, 16, 25],
[ 0, 7, 16, 27, 40, 55],
[ 0, 13, 28, 45, 64, 85],
[ 0, 19, 40, 63, 88, 115],
[ 0, 25, 52, 81, 112, 145]],
[[ 0, 31, 64, 99, 136, 175],
[ 0, 37, 76, 117, 160, 205],
[ 0, 43, 88, 135, 184, 235],
[ 0, 49, 100, 153, 208, 265],
[ 0, 55, 112, 171, 232, 295]],
[[ 0, 61, 124, 189, 256, 325],
[ 0, 67, 136, 207, 280, 355],
[ 0, 73, 148, 225, 304, 385],
[ 0, 79, 160, 243, 328, 415],
[ 0, 85, 172, 261, 352, 445]],
[[ 0, 91, 184, 279, 376, 475],
[ 0, 97, 196, 297, 400, 505],
[ 0, 103, 208, 315, 424, 535],
[ 0, 109, 220, 333, 448, 565],
[ 0, 115, 232, 351, 472, 595]]]]])
n = k.sum(dim=(2,3))
n.shape
Out[31]: torch.Size([1, 1, 6])
n
Out[32]: tensor([[[ 0, 1160, 2360, 3600, 4880, 6200]]])
1.2、测试.sum(dim=-1)
程序接上面程序执行 所以dim取值就是要消除的维度,dim取-1就是要消除最后一个维度。 对于上面的有5个维度的矩阵,维度索引 0,1,2,3,4。 先执行消除维度2,3,保留的维度0,1,4,然后再执行消除维度-1,即消除维度4,则剩下的维度是0,1.
o = n.sum(dim=-1)
o.shape
Out[34]: torch.Size([1, 1])
o
Out[35]: tensor([[18200]])
|