作者:机器视觉全栈er 来源:cvtutorials.com
2.1.2 索引
筛选出符合某种条件的subtensor。
torch.where: 根据布尔变量的值选择tensor中的元素,用法如下:
torch.where(condition, x, y)
下面举个简单的例子:
>>> import torch
>>> cvtutorials = torch.randn(3, 4)
>>> threshold = torch.zeros(3, 4)
>>> cvtutorials
tensor([[-1.6981, 1.0443, 2.7922, -0.8736],
[-2.0208, -0.4815, -0.1488, -0.9714],
[ 1.1035, 0.4089, 0.6279, 2.4600]])
>>> torch.where(cvtutorials > 0, cvtutorials, threshold)
tensor([[0.0000, 1.0443, 2.7922, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[1.1035, 0.4089, 0.6279, 2.4600]])
上面torch.where函数返回tensor的某个元素的值遵循这样的选择:如果cvtutorials中的某个元素大于0,那么保留,否则设置为0,用数学公式表达如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-itiDkBU3-1645712444677)(https://gitee.com/cvtutorials/cvtutorials_picbed/raw/master/pth_hbh/image-20220224145457438.png)]
torch.index_select: 沿着某个维度,通过index对输入tensor进行筛选。用法如下:
torch.index_select(input, dim, index, *, out=None)
举个例子说明下:
>>> cvtutorials = torch.randn(2,3)
>>> cvtutorials
tensor([[-0.9935, -0.9802, -0.6104],
[ 2.6251, -1.0099, 0.4752]])
>>> indices = torch.tensor([0, 1])
>>> torch.index_select(cvtutorials, 0, indices)
tensor([[-0.9935, -0.9802, -0.6104],
[ 2.6251, -1.0099, 0.4752]])
>>> torch.index_select(cvtutorials, 1, indices)
tensor([[-0.9935, -0.9802],
[ 2.6251, -1.0099]])
torch.masked_select: 根据设置的mask,返回一个一维的tensor(向量)。用法如下:
torch.masked_select(input, mask, *, out=None)
举个简单的例子:
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 1.1016, -1.5259, 1.1065],
[ 0.4838, -0.5521, 0.1556]])
>>> mask = torch.tensor([[False, True, True], [True, False, False]])
>>> mask
tensor([[False, True, True],
[ True, False, False]])
>>> torch.masked_select(cvtutorials, mask)
tensor([-1.5259, 1.1065, 0.4838])
从中可以看出,根据mask对输入tensor相应位置的元素进行筛选,mask某位置为True,则取出tensor相应位置的元素,否则,不取出。
还有一点,mask的shape不一定和tensor一样,但是需要broadcast到tensor上,例如:
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 0.8686, 0.0910, 1.8702],
[ 1.8140, -1.0902, 0.7051]])
>>> mask = torch.tensor([[False, True, True]])
>>> mask
tensor([[False, True, True]])
>>> torch.masked_select(cvtutorials, mask)
tensor([ 0.0910, 1.8702, -1.0902, 0.7051])
|