一.为什么需要邻居采样?
在GNN领域,大图是非常常见的,但由于GPU显存的限制,大图是无法放到GPU上进行训练的。为此,可以采用邻居采样,这样一来可以将GNN扩展到大图上。在PyG中,邻居采样的方式有很多种,具体详解torch_geometric.loader 。本文以GraphSage中的邻居采样为例进行介绍,其在PyG中实现为NeighborLoader 。
NeighborSampler 也是PyG中关于GraphSage中邻居采样的实现,但已经被弃用,在未来版本中会被删除。
二.NeighborLoader 详解
2.1 GraphSage邻居采样原理
假设采样的层数为
K
K
K,每层采样的邻居数为
S
k
S_k
Sk?,GraphSage中邻居采样是这样进行的:
- 步骤一:首先给定要采样邻居的小批量节点集
B
\mathcal{B}
B;
- 步骤二:对
B
\mathcal{B}
B的
1
1
1跳(hop)邻居进行采样,然后得到
B
1
\mathcal{B}_1
B1?,然后对
B
1
\mathcal{B}_1
B1?的
1
1
1跳邻居进行采样(即最初结点集的
2
2
2跳邻居)得到
B
2
\mathcal{B}_2
B2?,如此往复进行
K
K
K次,得到最初小批量节点集相关的一个子图。
下图左是GraphSage中给出的一个2层邻居采样的示例,其中每层采样的邻居数
S
k
S_k
Sk?是相等的(图中为3)。
2.2 API介绍
PyG中,GraphSage的邻居采样实现为torch_geometric.loader.NeighborLoader ,其初始化函数参数为:
def __init__(
self,
data: Union[Data, HeteroData],
num_neighbors: NumNeighbors,
input_nodes: InputNodes = None,
replace: bool = False,
directed: bool = True,
transform: Callable = None,
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
)
常用参数说明如下:
data :要采样的图对象,可以为异构图HeteroData ,也可以为同构图Data ;num_neighbors :每个节点每次迭代(每层)采样的最大邻居数,List[int] 类型,例如[2,2] 表示采样2层,每层中每个节点最多采样2个邻居;input_nodes :从原始图中采样得到的子图中需要包含的原始图中节点索引,即2.1节中最初的
B
\mathcal{B}
B,torch.Tensor() 类型;directed :如果设置为False ,将包括所有采样节点之间的所有边;**kwargs :torch.utils.data.DataLoader 的额外参数,例如batch_size ,shuffle (具体详见该API)。
2.3 采样实践
为了可视化的美观性,本小节采用的图数据是PyG中提供的KarateClub 数据集,该数据集描述了一个空手道俱乐部会员的社交关系,节点为34名会员,如果两位会员在俱乐部之外仍保持社交关系,则在对应节点间连边,该数据集的可视化如下所示:
下面是对该数据集的加载、可视化以及邻居采样的源码:
import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
def draw(graph):
nids = graph.n_id
graph = to_networkx(graph)
for i, nid in enumerate(nids):
graph.nodes[i]['txt'] = str(nid.item())
node_labels = nx.get_node_attributes(graph, 'txt')
nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
plt.axis("off")
plt.show()
dataset = KarateClub()
g = dataset[0]
g.n_id = torch.arange(g.num_nodes)
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.Tensor([14])):
draw(s)
break
在上述源码中,设置的采样层数为2层、每个节点每层采样最多采样2个邻居,采样的初始节点集为{14} ,其对应的采样结果如下所示:
从上图可以看出,在第一次迭代中,采样了节点{14} 的两个1跳邻居{32,33} ,然后在第二次迭代中对{32,33} 分别进行采样得到{2,8]} 和{18,30} 。
需要注意是通过NeighborLoader 返回的子图中,全局节点索引会映射到到与该子图对应的局部索引。因此,若要将当前采样子图中的节点映射会原来图中对应的节点,可以在原始图中创建一个属性来完成两者之间的映射,例如采样实践源码中的:
g.n_id = torch.arange(g.num_nodes)
如此以来,采样后子图中的节点同样包含n_id 属性,这样就可以将子图的节点映射回去了,上述示例中对图进行可视化便利用了这一点,其对应的映射为:
{0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
结语
PyG中对于邻居采样的实现远远不止上述这一种,具体参见如下官网资料:
|