torch.cat是将两个张量(tensor) 拼接在一起,cat是concatnate的意思, 即拼接,联系在一起。
使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。y即:当dim=0时,按行拼接;当dim=1时,按列拼接。
例如:
In [11]: import torch as tr
In [12]: A=tr.ones(2,3)
In [13]: A
Out[13]:
tensor([[1., 1., 1.],
[1., 1., 1.]])
In [14]: B=2*tr.ones(4,3)
In [15]: B
Out[15]:
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
In [16]: C=tr.cat((A,B),0)
In [17]: C
Out[17]:
tensor([[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
In [18]: C=tr.cat((A,B),1)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [18], in <cell line: 1>()
----> 1 C=tr.cat((A,B),1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 4 for tensor number 1 in the list.
??????????????
|