1. 问题:
大部分帖子和一些典型的介绍numpy、pytorch的书籍对此部分并没有详细的介绍,仅仅简单地在np.cat()或torch.stack()等中提到当连接轴指定为0,或1时按照某某连接、排列。然而,当连接轴值较大时(如3,4,5),笔者发现,这些函数的输出并不是大多数材料所写的那样。笔者甚至看不出任何规律,为探寻其轴操作的原理,做了如下实验。
2 归纳:
定义两个初始矩阵: 执行pytorch中的torch.stack([A, B], 2)操作,因为[A, B]可理解为将A,B连接,即
torch.stack([A, B], 2)的执行结果为:
将R 中的各个元素改写为C 中对应值的索引,可得:
根据索引的排布规则,可以看出,R是按照轴序为(1,2,0,3)的顺序排列 的。所以,实际上torch.stack()叠加其实是按照改变元素索引的方式排列的,而不是按照什么“行”,“列”排列的。归纳可推出torch.stack()按轴连接、排列的步骤如下:
-
首先得到排列轴的顺序。按下图所示,当指定按照m轴为排列轴时,将0轴挪入到对应的索引位置m,其他轴自动补齐。 -
计算出[A, B]中所有元素的索引坐标,把原始轴序列(0,1,2,3) 调整为排列轴序列(如 (1,2,0,3)),调整各元素的索引。 -
根据各元素的索引重新排列矩阵。
3 验证:
定义两个4阶张量:
[A2, B2]可表达为5阶张量:
torch.stack([A2, B2], 2)的结果为:
按照排列准则,C 中的轴排列(0,1,2,3,4)变为(1,2,0,3,4) 即:
对比:
验证归纳出的排列原理正确!
|