作者:机器视觉全栈er 来源:cvtutorials.com
2.1.3 分
分就是将tensor拆分或者只看tensor的某个部分。
torch.chunk:将一个tensor分解成多个块。用法如下:
torch.chunk(input, chunks, dim=0)
如果chunk不能被输入的tensor的dim方向上的整除的话,最后一个块和其他块的大小不一样,举个例子来说明下:
>>> cvtutorials = torch.randn(7, 7)
>>> cvtutorials
tensor([[-1.1886, -1.4592, -0.2017, -1.7050, 2.0357, -0.3754, -0.6668],
[ 0.5493, -0.8917, -0.3982, -2.4488, 1.8564, -0.6771, 0.5145],
[-0.9599, 0.7301, -1.1945, -0.4188, -0.4440, 1.5929, -0.8686],
[ 0.0075, 0.3351, 0.2337, 0.6663, 0.7913, 0.0717, 1.0995],
[ 0.6604, 0.6436, 0.2717, -1.0651, 0.6586, -1.4068, -1.1303],
[ 0.3116, 0.7408, -0.3726, 1.2334, -0.6076, -1.3587, -1.8686],
[-0.8115, -1.4766, 1.4367, 2.1566, -1.1522, -0.8364, -1.0398]])
>>> torch.chunk(cvtutorials, 3, 0)
(tensor([[-1.1886, -1.4592, -0.2017, -1.7050, 2.0357, -0.3754, -0.6668],
[ 0.5493, -0.8917, -0.3982, -2.4488, 1.8564, -0.6771, 0.5145],
[-0.9599, 0.7301, -1.1945, -0.4188, -0.4440, 1.5929, -0.8686]]),
tensor([[ 0.0075, 0.3351, 0.2337, 0.6663, 0.7913, 0.0717, 1.0995],
[ 0.6604, 0.6436, 0.2717, -1.0651, 0.6586, -1.4068, -1.1303],
[ 0.3116, 0.7408, -0.3726, 1.2334, -0.6076, -1.3587, -1.8686]]),
tensor([[-0.8115, -1.4766, 1.4367, 2.1566, -1.1522, -0.8364, -1.0398]]))
这里的dim默认是0,表示将行数进行分割,如果是1的话,表示将列数进行分割。
torch.tensor_split: 将输入的tensor分解成多个tensor,用法如下:
torch.tensor_split(input, indices_or_sections, dim=0)
下面我们分两种情况对torch.tensor_split举例说明:第一种情况是indices_or_sections参数接收的是一个整数,那么输入tensor会被分成多个同样大小的块,如果无法整除,最后一个维度会小点,举例如下:
>>> cvtutorials = torch.arange(5)
>>> cvtutorials
tensor([0, 1, 2, 3, 4])
>>> torch.tensor_split(cvtutorials, 2)
(tensor([0, 1, 2]), tensor([3, 4]))
>>> cvtutorials = torch.randn(3, 4)
>>> cvtutorials
tensor([[ 0.2840, -0.3296, -0.6659, 0.4259],
[-0.7710, 1.1885, -1.3157, 0.2106],
[ 0.4975, -0.2922, -0.5841, -1.3325]])
>>> torch.tensor_split(cvtutorials, 2)
(tensor([[ 0.2840, -0.3296, -0.6659, 0.4259],
[-0.7710, 1.1885, -1.3157, 0.2106]]), tensor([[ 0.4975, -0.2922, -0.5841, -1.3325]]))
第二种情况是indices_or_sections参数接收的是一个list,那么输入的tensor会按照列表的数值分成多个块,比起第一种情况定义更加灵活,举例如下:
>>> cvtutorials = torch.randn(3, 6)
>>> cvtutorials
tensor([[ 0.5489, -0.9790, -0.2129, 1.1956, 0.6439, -0.7969],
[ 1.1158, -0.9655, -0.3893, -0.4356, 0.0512, -0.8653],
[-0.5969, -1.0525, -0.3342, -1.1016, 0.6439, -1.5340]])
>>> torch.tensor_split(cvtutorials, (1, 5), dim=1)
(tensor([[ 0.5489],
[ 1.1158],
[-0.5969]]),
tensor([[-0.9790, -0.2129, 1.1956, 0.6439],
[-0.9655, -0.3893, -0.4356, 0.0512],
[-1.0525, -0.3342, -1.1016, 0.6439]]),
tensor([[-0.7969],
[-0.8653],
[-1.5340]]))
这里要注意:(1, 5)代表的意思是将输入的tensor沿着dim=1的方向,分为[:1], [1:5], [5:]三个块。
torch.dsplit: 将一个高维tensor进行分割,dsplit的全称是depthwise split,用法如下:
torch.dsplit(input, indices_or_sections)
这个函数的用法和torch.tensor_split(input, indices_or_sections, dim=2)等价。
torch.hsplit: 将一个tensor进行分割,hsplit的全称是horizontally split,用法如下:
torch.hsplit(input, indices_or_sections)
如果输入tensor是一维的,等价于torch.tensor_split(input, indices_or_sections, dim=0),如果输入tensor是高维的,等价于torch.tensor_split(input, indices_or_sections, dim=1)。
torch.vsplit:将一个tensor进行分割,vsplit的全称是vertically split,用法如下:
torch.vsplit(input, indices_or_sections)
这个函数等价于torch.tensor_split(input, indices_or_sections, dim=0)。
torch.split: 将tensor分成几个块,用法如下:
torch.split(tensor, split_size_or_sections, dim=0)
如果split_size_or_sections输入的是整数,就会将输入的tensor均匀的分成几个块(最后一个块大小不定),如果是个列表,就会按照列表的数值对输入tensor进行拆分,举个例子:
>>> cvtutorials = torch.randn(3, 3)
>>> cvtutorials
tensor([[ 0.4390, 0.7209, -0.6984],
[-0.1046, 0.7321, 0.1514],
[ 0.2743, -0.5920, 0.8933]])
>>> torch.split(cvtutorials, 2)
(tensor([[ 0.4390, 0.7209, -0.6984],
[-0.1046, 0.7321, 0.1514]]), tensor([[ 0.2743, -0.5920, 0.8933]]))
>>> torch.split(cvtutorials, [1, 2])
(tensor([[ 0.4390, 0.7209, -0.6984]]), tensor([[-0.1046, 0.7321, 0.1514],
[ 0.2743, -0.5920, 0.8933]]))
torch.narrow:返回输入tensor的缩减版,用法如下:
torch.narrow(input, dim, start, length)
下面举个例子对torch.narrow的用法进行说明:
>>> import torch
>>> cvtutorials = torch.randn(3, 3)
>>> cvtutorials
tensor([[-0.0255, -1.3090, 2.0309],
[ 0.1062, 0.5078, 0.9572],
[ 1.0461, 0.3789, -1.2349]])
>>> torch.narrow(cvtutorials, dim=0, start=0, length=2)
tensor([[-0.0255, -1.3090, 2.0309],
[ 0.1062, 0.5078, 0.9572]])
dim为0的时候,表示沿着行进行缩减,start表示索引,从0开始,length表示截取的行数,即从cvtutorials中沿着行截取,从第“0”行开始,截取两行。
torch.unbind:将tensor的某个维度去掉,dim的默认值为0,指的是沿着“行”进行切片,用法如下:
torch.unbind(input, dim=0)
其实,unbind可以看作是chunk函数的一个特例。下面举个例子解释下unbind的用法:
>>> cvtotrials = torch.randn(3, 3)
>>> torch.unbind(cvtutorials, 0)
(tensor([-1.1886, -1.4592, -0.2017, -1.7050, 2.0357, -0.3754, -0.6668]),
tensor([ 0.5493, -0.8917, -0.3982, -2.4488, 1.8564, -0.6771, 0.5145]), tensor([-0.9599, 0.7301, -1.1945, -0.4188, -0.4440, 1.5929, -0.8686]),
tensor([0.0075, 0.3351, 0.2337, 0.6663, 0.7913, 0.0717, 1.0995]),
tensor([ 0.6604, 0.6436, 0.2717, -1.0651, 0.6586, -1.4068, -1.1303]),
tensor([ 0.3116, 0.7408, -0.3726, 1.2334, -0.6076, -1.3587, -1.8686]), tensor([-0.8115, -1.4766, 1.4367, 2.1566, -1.1522, -0.8364, -1.0398]))
|