Indexing: dim 0 first
# A random 4-D float tensor; the indexing examples below slice this tensor.
a = torch.rand((4, 3, 28, 28))
print(a.ndim)    # number of dimensions (same value as a.dim())
print(a.size())  # same torch.Size object that a.shape reports
4
torch.Size([4, 3, 28, 28])
# Each integer index removes one leading dimension of `a`
# (assumes `a` is the (4, 3, 28, 28) tensor created above):
# -> (3, 28, 28), (28, 28), (28,), and finally () for a single element.
for _depth in range(1, 5):
    print(a[(0,) * _depth].shape)
Indexing with ... (Ellipsis)
# `...` (Ellipsis) stands for "every dimension not written out", so only the
# explicitly indexed axes are touched (with `a` of shape (4, 3, 28, 28)):
# -> (3, 28, 28), (3, 28), (4, 3, 28, 2).
for _key in ((0, Ellipsis), (0, Ellipsis, 0), (Ellipsis, slice(None, 2))):
    print(a[_key].shape)
select first/last N
# Slices keep a dimension but shrink it; trailing dims that are not written
# out are taken whole (with `a` of shape (4, 3, 28, 28)):
# -> (2, 3, 28, 28), (2, 1, 28, 28), (3, 3, 28, 28).
print(a[0:2].shape)
print(a[0:2, 0:1].shape)
print(a[1:].shape)
select by steps
# start:stop:step keeps every 2nd row/column of the last two dims; the
# explicit-bounds form and the `::2` shorthand give the same result here
# (with `a` of shape (4, 3, 28, 28)) -> (4, 3, 14, 14) both times.
print(a[:, :, :28:2, :28:2].shape)
print(a[..., ::2, ::2].shape)
select by specific index
# index_select gathers the listed positions along a single dim, given a
# 1-D integer index tensor (with `a` of shape (4, 3, 28, 28)):
# -> (2, 3, 28, 28) and (4, 2, 28, 28).
print(torch.index_select(a, 0, torch.tensor([0, 2])).shape)
print(torch.index_select(a, dim=1, index=torch.arange(2)).shape)
select by mask
# Fresh 3x4 random tensor for the mask example (rebinds `a`).
a = torch.rand((3, 4))
print(a)
tensor([[0.9796, 0.9025, 0.2744, 0.4932],
        [0.6778, 0.5818, 0.7009, 0.6437],
        [0.2674, 0.8005, 0.6140, 0.6765]])
# Boolean 3x4 mask with exactly one True per row, at (0, 3), (1, 0), (2, 1).
mask = torch.zeros((3, 4), dtype=torch.bool)
mask[[0, 1, 2], [3, 0, 1]] = True
# masked_select returns a flattened 1-D tensor of the selected values, so
# with three True entries the result has shape (3,). Note: `a` is rebound.
a = a.masked_select(mask)
print(a)
print(a.shape)
|