一.前言
上篇文章《PyG教程(6):自定义消息传递网络》主要介绍了消息传递GNN的大致框架。本文主要聚焦于消息传播中的邻域聚合,本文将介绍PyG是如何将节点的邻居的消息聚合到节点本身的。
二.PyG中的邻域聚合
PyG中邻域聚合是通过aggregate(inputs, index) 函数来完成的,该函数的第一个参数inputs 为消息构建函数message() 构建的消息,该函数还存在一个参数index ,这个参数对于消息聚合是十分关键的,它指示了inputs 中每条消息属于哪个节点的邻域。下图便很好的解释了PyG中的消息聚合:
上述栗子中展示的是包含4个顶点、8条边的graph,其中input 为在8条边上传播的消息、index 为各条边上消息的归属,即目标节点的索引。通过index ,可以将属于同一个节点邻域的消息聚合到一起,常见的聚合包括sum 、mean 、mean 、mul 和min 等。
在PyG中通过scatter 函数来实现上述过程,查看MessagePassing 的源码,可以看到其aggregate 函数的定义如下:
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
注释太长略
"""
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
aggregate 函数中scatter 函数源码为:
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
r"""
注释太长略
"""
if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size)
if reduce == 'mul':
return scatter_mul(src, index, dim, out, dim_size)
elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min':
return scatter_min(src, index, dim, out, dim_size)[0]
elif reduce == 'max':
return scatter_max(src, index, dim, out, dim_size)[0]
else:
raise ValueError
其中便包含了前面提到的5种聚合方式。对于这些聚合方式,只需要在继承MessagePassing 类时,通过super().__init__ 来向该类传递参数aggr 参数的值即可。
三.torch_scatter 模块
若用户需要自定义消息聚合,则在重写的aggregate() 函数中,同样可以使用MessagePassing 中的scatter 函数,只需要导入torch_scatter 模块即可。
在torch_scatter 模块中也实现了scatter 函数,其声明如下:
scatter: (src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = "sum") -> Tensor
常用参数说明:
参数 | 说明 |
---|
src | 每条边上的源节点生成的消息 | index | 指示每条边上消息需要聚合到哪个节点上 | dim | 指示沿着那个维度(轴)应用index 进行聚合 | reduce | 聚合操作,包括sum 、mul 、 mean 、 min 和 max |
注意,torch_scatter 也为上述的几种聚合单独提供了API:
torch_scatter.scatter_add()
torch_scatter.scatter_max()
torch_scatter.scatter_mean()
torch_scatter.scatter_min()
torch_scatter.scatter_mul()
为了方便理解,下面给出一个栗子,假设存在一个包含3个顶点、6条边的图:
假设0、1、2三个顶点生成的消息分别为1、2、3,则图中6条边的消息inputs 和相应的index 构造如下:
inputs = torch.tensor([[1], [1], [2], [2], [3], [3]])
index = torch.tensor([1, 2, 0, 2, 0, 1])
应用torch_scatter.scatter() 函数的结果如下:
out = torch_scatter.scatter(src=inputs, index=index, dim=0, reduce="sum")
print(out)
"""
tensor([[5],
[4],
[3]])
"""
可以看到节点0接受来自节点1,2的消息得到2+3=5 ,节点1接受来自节点0,2的消息得到1+3=4 ,而节点2接受来自节点0,1的消息得到1+2=3 。
四.结语
参考资料:
通过本文可以加深对PyG中消息聚合过程的理解,这将有助于更好的自定义GNN模型。以上便是本文的全部内容,若有任何错误,请批评指正。
|