一、Scatter函数的理解
??这个函数是用一个src的源张量或者标量以及索引来修改另一个张量,常用来做one-hot编码。这个函数主要有三个参数 scatter(dim,index,src)
- dim:沿着哪个维度来进行索引(一会儿举个例子就明白了)
- index:用来进行索引的张量
- src:源张量或者标量
self[index[i][j][k]][j][k] = src[i][j][k]
self[i][index[i][j][k]][k] = src[i][j][k]
self[i][j][index[i][j][k]] = src[i][j][k]
self[index[x][y]][y]=src[x][y]
self[x][index[x][y]]=src[x][y]
二、Scatter函数进行独热编码(one-hot)
import torch
index = torch.arange(5).unsqueeze(1)
'''
index =
tensor([[0],
[1],
[2],
[3],
[4]])
'''
one_hot = torch.zeros(5,5).scatter_(1, index, 1)
'''
one_hot =
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
'''
具体的可以参考这篇博客:PyTorch笔记之 scatter() 函数
|