普通广播
>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
>>> c[:,0]
tensor([0, 1])
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
[0, 1, 0]])
为什么会出现这样的结果?
赋值语句的意思是:
- range(n)表示对b的所有行进行赋值操作
- c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
- 最右边的1表示对应索引位置所赋的值
scatter函数
import torch
label = torch.zeros(3, 6)
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
print(a)
label.scatter_(1,torch.LongTensor(b),a)
print("new_label: ",label)
label:
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
new_label:
tensor([[1., 1., 1., 0., 0., 0.],
[1., 1., 0., 1., 0., 0.],
[0., 1., 1., 1., 0., 0.]])
举例:
>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
[0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
[1, 0]])
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
[1, 1, 0]])
感谢!https://blog.csdn.net/qq_41368074/article/details/106986753
|