一、cat函数
torch.catimport torch
a=torch.randn(3,4)
b=torch.randn(2,4)
print("a:")
print(a)
print("b:")
print(b)
print("拼接结果:")
print(torch.cat([a,b],dim=0))
a:
tensor([[ 2.1754, 1.4698, 1.4103, -0.2498],
[ 0.3248, -1.9372, -0.9310, -0.3833],
[-0.3603, -0.0271, -0.1942, -0.0345]])
b:
tensor([[-0.3917, 0.1332, -1.0066, 0.3633],
[-0.2378, 0.5224, 1.1371, 0.6401]])
拼接结果:
tensor([[ 2.1754, 1.4698, 1.4103, -0.2498],
[ 0.3248, -1.9372, -0.9310, -0.3833],
[-0.3603, -0.0271, -0.1942, -0.0345],
[-0.3917, 0.1332, -1.0066, 0.3633],
[-0.2378, 0.5224, 1.1371, 0.6401]])
二、stack函数
创建一个新的维度。
要求:两个tensor拼接前“拼接维度”的形状要完全一致
torch.stack要求:两个tensor拼接前的形状完全一致import torch
a=torch.randn(3,4)
b=torch.randn(3,4)
print("a: ")
print(a)
print("b: ")
print(b)
c=torch.stack([a,b],dim=0)
d=torch.stack([a,b],dim=1)
print("d: ")
print(d)
这里的关键词参数dim的理解和cat方法中有些区别。
cat方法中可以理解为原tensor的维度,dim=0,就是沿着原来的0轴进行拼接,dim=1,就是沿着原来的1轴进行拼接。
stack方法中的dim则是指向新增维度的位置,dim=0,就是在新形成的tensor的维度的第0个位置新插入维度
三、split函数
split是根据长度去拆分tensor
import torch
a=torch.randn(3,4)
print('a :')
print(a)
print("按维度0拆分 : ")print(a.split([1,2],dim=0))
print("按维度1拆分 : ")
print(a.split([2,2],dim=1))
四、chunk函数
chunk可以理解为均等分的split,但是当维度长度不能被等分份数整除时,虽然不会报错,但可能结果与预期的不一样,建议只在可以被整除的情况下运用
import torch
a=torch.randn(4,6)
print("a :")
print(a)
print(a.chunk(2,dim=1))
a :
tensor([[-0.4875, 1.4914, 0.2244, -0.5883, -0.5951, -0.4857],
[-0.1344, -0.6973, -0.2042, 2.5817, -0.7972, -0.6522],
[ 1.4379, -0.1185, 0.4457, -1.1168, 1.0184, -0.5088],
[-0.7692, 1.4040, -0.2799, 1.1515, 0.2329, 0.4926]])
(tensor([[-0.4875, 1.4914, 0.2244],
[-0.1344, -0.6973, -0.2042],
[ 1.4379, -0.1185, 0.4457],
[-0.7692, 1.4040, -0.2799]]), tensor([[-0.5883, -0.5951, -0.4857],
[ 2.5817, -0.7972, -0.6522],
[-1.1168, 1.0184, -0.5088],
[ 1.1515, 0.2329, 0.4926]]))
参考资料: Pytorch:Tensor的合并与分割 pytorch tensor 的拼接和拆分
|