Inductive Representation Learning on Large Graphs
本文发表于NIPS 2017.GraphSAGE在时间和表现上都获得了最优。
abstract
大型图的低维嵌入具有很大的应用价值。但是现存的方法都要求图中的所有节点在训练时都存在,本质上是transductive的,无法自然地泛化到未见过的节点上。因此我们提出了GraphSAGE,一种inductive的框架,利用节点特征信息来为未见过的数据生成节点embeddings。不是通过单独训练每个节点的embedding,而是通过从节点的局部邻居上采样和聚合特征来学习函数生成embedding。我们的算法在三个inductive 节点分类的数据集上超越了其他的strong baselines:我们基于引文和 Reddit 帖子的演化信息图中的数据对看不见的节点的类别进行分类,同时我们表明我们的算法可以推广到完全未知的图通过使用蛋白质-蛋白质交互的多图数据集。
1.introduction
节点嵌入方法的基础idea是:使用维度下降方法来将节点的邻居的高纬信息蒸馏到一个密集的向量嵌入上。然而,之前的工作都聚焦于从单个固定图赏学习节点embedding,但是许多实际应用需要从未知节点上快速获取embedding,甚至是从全新的图上。这种inductive的能力对于高吞吐量,量产的机器学习系统是必要的,当进行图演化和持续碰到未见过的节点之时。一种生成节点嵌入的归纳方法(inductive approach)也促进具有相同特征形式的图之间的泛化:例如,可以训练来自模型生物的蛋白质-蛋白质交互图上的嵌入生成器,然后使用这种经过训练的模型轻松地为在新生物上收集的数据生成节点嵌入。
这种归纳的节点嵌入方法很难,与transductive setting相比,泛化未知节点需要将新观察到的图与原来训练好的算法的节点的嵌入进行对齐。inductive 框架必须学会识别节点邻居的结构属性以及揭示节点的局部属性,以及节点的全局位置。
现存的节点嵌入方法多数本质上是transductive。多数通过矩阵分解的方法在单个固定图上完成对每个节点的嵌入工作。然而这些方法可以被改进为inductive setting,因为这些改进需要巨大的计算量。总之,现存的GCN都是应用于固定图上的transductive setting。本次研究我们将GCN扩展到无监督的inductive learning,以及将GCN推广到使用可训练的聚合函数(除了卷积可训练之外)。
present work:
我们提出了一个通用框架GraphSAGE(sample and aggregate),来进行inductive node embedding。不像传统的基于矩阵分解的embedding方法,我们利用了节点的特征,为了对未知节点进行嵌入方程的学习。通过将节点特征加入学习算法,我们同时学习了邻居节点的拓扑结构以及邻居节点的特征分布。我们的方法也可以用于没有特征信息的图。
不是对每个节点学习一个embedding向量,我们训练一个aggregator functions的集合来从节点的局部邻居上聚合特征信息。每个aggregator function从不同跳数的邻居上聚合信息。在测试阶段,通过学习到的聚合函数来为没见过的节点生成embedding。我们设计了一个无监督损失函数来允许GraphSAGE在无具体任务的监督下进行训练,但是它也可以在有监督的情况下进行训练。
通过实验,我们验证了我们的方法能够为没见过的节点生成embeddings以及很大程度上超越了其他baselines。最后,我们还验证了我们的方法的表达能力,通过理论分析,GraphSAGE具有学习图中节点的结构信息的能力,尽管它是基于节点的学习。
3.proposed method:GraphSAGE
我们方法的核心idea是:如何从节点的局部邻居学习去聚合信息。
3.1 embedding generation(forward propagation)algorithm
假设我们已经学习了
K
K
K个aggregator functions的参数(
A
G
G
R
E
G
A
T
O
R
k
,
?
k
∈
{
1
,
.
.
.
,
K
}
AGGREGATOR_k,\forall k \in \{1,...,K\}
AGGREGATORk?,?k∈{1,...,K}),它从邻居节点中聚合信息,其实就是一个权重矩阵的集合
W
k
,
?
k
∈
{
1
,
.
.
.
,
K
}
W^k,\forall k \in \{1,...,K\}
Wk,?k∈{1,...,K},用于在模型的不同层之间聚合信息。
算法1中当迭代过程持续进行时,节点获得越来越信息直到覆盖全图的信息。
如何将算法1推广到minibatch setting:给定输入节点的集合,我们首先采样出需要的邻居集合,然后再进行邻居的聚合操作,但是不是对所有的节点都进行迭代操作,而是只计算满足每层递归的表征。
3.1.1 relation to the Weisfeiler-Lehman Isomorphism Test
Weisfeiler-Lehman Isomorphism Test:用于测试两个图的同构性。如果输出一致则表明两个图是同构的。这种test在某些情况下是有误的,但是在大多数情况下是适用的。GraphSAGE可以被视作是WL test的连续的逼近,但是被用于节点的embedding生成,而非是同构性验证。WL test是GraphSAGE学习邻居节点的拓扑结构的理论基础。
如果将
K
=
∣
V
∣
K=|V|
K=∣V∣,将所有的权重矩阵设为一致,使用合适的无非线性的hash函数来替代聚合函数我们就得到了Weisfeiler-Lehman Isomorphism Test。
3.1.2 neighborhood definition
在本文中,我们采样了固定数目的邻居,而不是使用全部的邻居节点。per-batch的空间和时间复杂度为
O
(
∑
i
=
1
K
S
i
)
O(\sum_{i=1}^K S_i)
O(∑i=1K?Si?),其中
S
i
,
i
∈
{
1
,
.
.
.
,
K
}
S_i,i \in \{1,...,K\}
Si?,i∈{1,...,K}以及
K
K
K都是用户自定义的常数。
3.2 learning the parameters of GraphSAGE
进行完全无监督的学习。这种基于图的损失函数使得相近的节点具有相似的表征,距离远的节点的表示则高度不同。
J
G
(
z
u
)
=
?
l
o
g
(
σ
(
z
u
T
z
v
)
)
?
Q
?
E
v
n
~
P
n
(
v
)
l
o
g
(
σ
(
?
z
u
T
z
v
n
)
)
.
J_G(z_u)=-log(\sigma(z_u^Tz_v))-Q\cdot\mathbb{E}_{v_n\sim P_n(v)}log(\sigma(-z_u^Tz_{v_n})).
JG?(zu?)=?log(σ(zuT?zv?))?Q?Evn?~Pn?(v)?log(σ(?zuT?zvn??)).
其中
v
v
v是一个在固定长度的随机游走上的
u
u
u附近共现的节点,
P
n
P_n
Pn?是一个负采样分布,
Q
Q
Q代表负采样的样本数目。
这种无监督的损失函数可以随时被替换、增强或者是task-specific的目标(如交叉熵损失)。
3.3 aggregator architectures
算法1中的聚合函数必须作用于无序的节点集合,因此聚合函数必须是对称的(输入的节点顺序不影响结果)同时保持可训练以及保持高表达能力。聚合函数的对称属性保证了我们的神经网络模型可以被训练和应用到任意顺序的节点邻居特征集合之上。我们验证了一下三种候选的聚合函数:
3.3.1 mean aggregator
可以将算法1的四五行替换成以下公式:
h
v
k
←
σ
(
W
?
M
E
A
N
(
{
h
v
k
?
1
}
?
{
h
u
k
?
1
,
?
u
∈
N
(
u
)
}
)
)
.
h_v^k \gets \sigma(W \cdot MEAN(\{h_v^{k-1}\}\bigcup\{h_u^{k-1},\forall u \in N(u)\})).
hvk?←σ(W?MEAN({hvk?1?}?{huk?1?,?u∈N(u)})).
我们将这种改进后的基于均值的聚合器称为卷积,因为它是一个粗糙的对局部频域卷积的线性估计。这种聚合信息的方式和我们其他的方式最大的差别就是缺少了concatenation操作,这种操作实际上是一种在不同层之间的skip-connection,这样可以获得很大的性能提升。
3.3.2 LSTM aggregator
与mean aggregator相比,这种基于LSTM的聚合器具有更好的表达能力。但是需要注意的是,LSTM本质上并不是对称的(它不是计算不变的,受输入顺序的干扰),因为LSTM的输入是一个序列。我们改进lstm到无序的集合上,通过将lstm应用到对节点邻居的随机洗牌上。
3.3.3 pooling aggregator
最后介绍的这种聚合器同时具有对称性和可训练性。在pooling方法下,每个邻居向量被独立地通过一个全连接的神经网络。
A
G
G
R
E
G
A
T
E
k
p
o
o
l
=
m
a
x
(
{
σ
(
W
p
o
o
l
h
u
i
k
+
b
)
,
?
u
i
∈
N
(
v
)
}
)
,
AGGREGATE_k^{pool}=max(\{\sigma(W_{pool}h_{u_i}^k+b),\forall u_i \in N(v)\}),
AGGREGATEkpool?=max({σ(Wpool?hui?k?+b),?ui?∈N(v)}),
本质上,这里的MLP可以被视为是对邻居集合的节点表征计算特征。事实上,任何对称的向量方程都可以替代这里的max操作,例如mean,但是事实上证明,这里的两种操作并没有什么实际效果的差异。
4.experiments
4.1 inductive learning on evolving graphs:citation and Reddit data
最开始的两个实验是在演化信息图上的节点分类问题,这通常和高吞吐量的系统相关,经常碰到unseen data。GraphSAGE超越了其他baseline模型的表现。
4.2 generalizing graphs:protein-protein interactions
GraphSAGE超越了其他baseline模型的表现。
4.3 runtime and parameter sensitivity
尽管通过下采样引入了高方差,GraphSAGE仍然能够保持很强的预测准确度,并且大大节省了运行时间。
4.4 summary comparison between the different aggregator architectures
LSTM-和pool-based aggregator表现最佳。但是LSTM-based要比poll慢得多。
5. theoretical analysis
clustering coefficients:反映了节点的一跳邻居的聚集程度。我们证明了算法1具有对任意程度的准确度的图的clustering coefficients的预测的能力。 theorem 1:存在一组参数
θ
?
\theta^*
θ?for 算法1,使得在
K
=
4
K=4
K=4的迭代之后:
∣
z
v
?
c
v
∣
<
?
,
?
v
∈
V
.
|z_v-c_v|<\epsilon,\forall v \in V.
∣zv??cv?∣<?,?v∈V.
z
v
z_v
zv?为算法1的最终输出,而
c
v
c_v
cv?为节点的clustering coefficients。说明了GraphSAGE可以学习到局部的图结构,即使节点特征都是随机采样出来的。上述定理的证明过程的基础idea是每个节点都具有独特的特征表示,我们可以学习来映射节点到指示向量以及识别每个节点的邻居。pool aggregator超越了GCN和mean-based aggregator。
|