IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> PyG教程(7):剖析邻域聚合 -> 正文阅读

[Python知识库]PyG教程(7):剖析邻域聚合

一.前言

上篇文章《PyG教程(6):自定义消息传递网络》主要介绍了消息传递GNN的大致框架。本文主要聚焦于消息传播中的邻域聚合,本文将介绍PyG是如何将节点的邻居的消息聚合到节点本身的。

二.PyG中的邻域聚合

PyG中邻域聚合是通过aggregate(inputs, index)函数来完成的,该函数的第一个参数inputs为消息构建函数message()构建的消息,该函数还存在一个参数index,这个参数对于消息聚合是十分关键的,它指示了inputs中每条消息属于哪个节点的邻域。下图便很好的解释了PyG中的消息聚合:

scatter

上述栗子中展示的是包含4个顶点、8条边的graph,其中input为在8条边上传播的消息、index为各条边上消息的归属,即目标节点的索引。通过index,可以将属于同一个节点邻域的消息聚合到一起,常见的聚合包括summeanmeanmulmin等。

在PyG中通过scatter函数来实现上述过程,查看MessagePassing的源码,可以看到其aggregate函数的定义如下:

def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None) -> Tensor:
    r"""
    注释太长略
     """
    if ptr is not None:
        ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
        return segment_csr(inputs, ptr, reduce=self.aggr)
    else:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggr)

aggregate函数中scatter函数源码为:

def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
	注释太长略
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    if reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError

其中便包含了前面提到的5种聚合方式。对于这些聚合方式,只需要在继承MessagePassing类时,通过super().__init__来向该类传递参数aggr参数的值即可。

三.torch_scatter模块

若用户需要自定义消息聚合,则在重写的aggregate()函数中,同样可以使用MessagePassing中的scatter函数,只需要导入torch_scatter模块即可。

torch_scatter模块中也实现了scatter函数,其声明如下:

scatter: (src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = "sum") -> Tensor

常用参数说明

参数说明
src每条边上的源节点生成的消息
index指示每条边上消息需要聚合到哪个节点上
dim指示沿着那个维度(轴)应用index进行聚合
reduce聚合操作,包括summulmeanminmax

注意,torch_scatter也为上述的几种聚合单独提供了API:

torch_scatter.scatter_add()
torch_scatter.scatter_max()
torch_scatter.scatter_mean()
torch_scatter.scatter_min()
torch_scatter.scatter_mul()

为了方便理解,下面给出一个栗子,假设存在一个包含3个顶点、6条边的图:

scatter_example

假设0、1、2三个顶点生成的消息分别为1、2、3,则图中6条边的消息inputs和相应的index构造如下:

inputs = torch.tensor([[1], [1], [2], [2], [3], [3]])
index = torch.tensor([1, 2, 0, 2, 0, 1])

应用torch_scatter.scatter()函数的结果如下:

out = torch_scatter.scatter(src=inputs, index=index, dim=0, reduce="sum")
print(out)
"""
tensor([[5],
        [4],
        [3]])
"""

可以看到节点0接受来自节点1,2的消息得到2+3=5,节点1接受来自节点0,2的消息得到1+3=4,而节点2接受来自节点0,1的消息得到1+2=3

四.结语

参考资料:

通过本文可以加深对PyG中消息聚合过程的理解,这将有助于更好的自定义GNN模型。以上便是本文的全部内容,若有任何错误,请批评指正。

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-06-29 19:00:05  更:2022-06-29 19:03:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/27 2:01:30-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计