建立消息传递网络MPNN
本文主要从实现层面讲解代码。
在Graph中,如果将卷积这一操作推广到其他域中时,往往用邻域聚合或消息传递来表示。 接下来有定义几个Notation:
x
i
(
k
)
\textbf{x}_i^{(k)}
xi(k)?表示第
k
k
k次迭代的节点
i
i
i的节点特征;
e
j
,
i
\textbf{e}_{j,i}
ej,i?表示从节点
j
j
j到节点
i
i
i的(可选的)边缘特征;
i
i
i被约定为单向边的目的节点;
j
j
j被约定为单向边的源节点;
N
(
i
)
\mathcal{N}(i)
N(i)表示节点
i
i
i的邻居节点(即,与
i
i
i有直接连边的节点); 简单来讲,这个GNN模型可以被表征为:
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
?
1
)
,
j
∈
N
(
i
)
?
(
x
i
(
k
?
1
)
,
x
j
(
k
?
1
)
,
e
j
,
i
)
)
\textbf{x}_i^{(k)}=\gamma^{(k)}(\textbf{x}_i^{(k-1)},\boxed{}_{j \in \mathcal{N}(i)}\phi(\textbf{x}_i^{(k-1)},\textbf{x}_j^{(k-1)},\textbf{e}_{j,i}))
xi(k)?=γ(k)(xi(k?1)?,?j∈N(i)??(xi(k?1)?,xj(k?1)?,ej,i?)) 其中,
\boxed{}
?(方框操作)是代表了一种可导的(differentiable)且置换不变的(permutation)函数,在Pytorch Geometric中,提供了sum,mean,max三种操作,接下来这种“方框操作”将被称为“聚合函数”。
γ
\gamma
γ和
?
\phi
?是两种不同的可导函数,以此来进行所谓的“特征提取”,常见的比如简单的MLP。
Message Passing基类
在PyTorch Geometric中,提供了Message Passing(MP)基类来帮助我们构建MPNN,官方文档给出了一个很好的范例:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add')
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
为了方便,我们一部分一部分的来看。
init
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add')
self.lin = torch.nn.Linear(in_channels, out_channels)
利用MP基类,我们建立属于自己的类GCNConv,并对其进行初始化操作,包括设置参数,而MP基类中可被用户修改定义的参数包括: MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2) 聚合函数(“add”, “mean” or “max”);默认add,也是最常用的。 信息传递方向(“source_to_target” or “target_to_source”).;默认前一个。 node_dim个人认为不太用,如需要可参考[1]。 此外,在这里也可以定义我们需要的其它库函数,例如线性函数、MLP、GRU等等。
forward
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)
我们对消息传递的主要操作就在这个forward中进行啦。 这里给了几个基本操作,但不一定会用到。包括加自环、线性变换、根据度来计算归一化等。 为了便于理解,我们就把forward实现的公式放在这里,有兴趣的可以自行理解:
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
?
(
i
)
?
deg
?
(
j
)
?
(
Θ
?
x
j
(
k
?
1
)
)
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),
xi(k)?=j∈N(i)∪{i}∑?deg(i)
??deg(j)
?1??(Θ?xj(k?1)?),
propagate
那么MP类中最重要的内容其实就是它的propagate函数了,如果这个函数被调用,那么MP类会隐式的调用如下三个函数:message(), aggregate(), update()。
aggregate()基本上就是我们在init部分规定好的参数了; message()主要实现一开始的公式中的
?
\phi
?这一部分; update()实现了
γ
\gamma
γ部分。
如果在类中不显式的说明这三个函数,那么就是直接输入即输出。因此我们一般都是要至少修改其中一个函数的。在修改的过程中,所有的包含“源和目的节点”这两个属性的变量都可以很方便的表达,比如
x
x
x是表示特征的变量,那么调用
x
j
x_j
xj?就是所有的源节点的特征,
x
i
x_i
xi?就是所有目的节点的特征。
调用类
conv = GCNConv(16, 32)
x = conv(x, edge_index)
这里调用的时候需要说明一下,输入参数表是跟着MP类所属forward()函数需要的参数列表来的。这里,edge_index和特征x的确定,可根据[2]中的讲解来进行。
Reference
[1] https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html#the-messagepassing-base-class [2] https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html
|