官方文档:https://pytorch.org/docs/stable/generated/torch.cat.html?highlight=torch%20cat#torch.cat
形式:torch.cat(tensors, dim=0, *, out=None) → Tensor
作用:连接给定维度中给定的 seq 张量序列。所有张量必须具有相同的形状(连接维度除外)或为空。
torch.cat() 可以看作是 torch.split() 和 torch.chunk() 的逆运算。
参数:
- tensors (sequence of Tensors) – any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
- dim (int, optional) – the dimension over which the tensors are concatenated
输出:
out (Tensor, optional) – the output tensor.
使用案例:
x = torch.randn(2, 3)
torch.cat((x, x, x), 0)
torch.cat((x, x, x), 1)
|