GNNExplainer
论文名称:GNNExplainer: Generating Explanations for Graph Neural Networks
论文地址:https://arxiv.org/abs/1903.03894
GNN使用节点的特征和图的结构作为信息沿着边传递。这种整合使得模型的可解释性更加困难。我们建议的模型GNNEXPLAINER,是一种与模型无关的,可以为任何的GNN模型提供解释。GNNEXPLAINER能够识别子图的结构和节点的特征,然后,对样本的实例作出解释。GNNEXPLAINER作为优化器,最大化GNN预测任务和子图结构之间的互信息,能够识别重要的图结构和特征。
GNNEXPLAINER将 trained GNN and its prediction(s)作为输入,返回输入图的子图和对预测结果产生影响的特征(Figure 1)。该方法是与模型无关的,可以解释基于GNN的机器学习任务,包括:节点分类、链路预测、图分类,它可以处理单条和多条样本。当处理单条样本时,GNNEXPLAINER针对该样本进行解释。(a node label, a new link, a graph-level label)。当处理多条样本时,针对该样本集合进行解释。
GNNEXPLAINER用GNN训练时整个图的子图进行解释,该子图最大化与预测值之间互信息。
1. Formulating explanations for graph neural networks
设图为
G
G
G, 边为
E
E
E, 节点为
V
V
V, 节点的特征为
d
d
d 维,
X
=
{
x
1
,
…
,
x
n
}
,
x
i
∈
R
d
\mathcal{X}=\left\{x_{1}, \ldots, x_{n}\right\}, x_{i} \in \mathbb{R}^{d}
X={x1?,…,xn?},xi?∈Rd,其中,
n
n
n是节点的数量。
f
f
f是节点label的映射函数。
f
:
V
?
{
1
,
…
,
C
}
f: V \mapsto\{1, \ldots, C\}
f:V?{1,…,C}, 将
V
V
V中的每个节点映射为
C
C
C类, GNN模型
Φ
\Phi
Φ在所有训练节点上进行优化,对新的节点进行预测。
1.1 Background on graph neural networks
在
l
l
l层, GNN模型包括关键三步。(1)第一步,计算节点对
(
v
i
,
v
j
)
(v_i,v_j)
(vi?,vj?)之间的message,
h
i
l
?
1
\mathbf{h}_i^{l-1}
hil?1?和
h
j
l
?
1
\mathbf{h}_j^{l-1}
hjl?1?分别是前一层节点
i
i
i和节点
j
j
j的表示,
r
i
j
r_{ij}
rij?是两个节点之间的关系:
m
i
j
l
=
MSG
?
(
h
i
l
?
1
,
h
j
l
?
1
,
r
i
j
)
m_{i j}^{l}=\operatorname{MSG}\left(\mathbf{h}_{i}^{l-1}, \mathbf{h}_{j}^{l-1}, r_{i j}\right)
mijl?=MSG(hil?1?,hjl?1?,rij?)(2),第二步,对于每个节点
v
i
v_i
vi?, GNN汇总aggregates它的邻居
N
v
i
\mathcal{N}_{v_i}
Nvi??的信息, aggregated message
M
i
M_i
Mi?的计算方式:
M
i
l
=
AGG
?
(
{
m
i
j
l
∣
v
j
∈
N
v
i
}
)
M_{i}^{l}=\operatorname{AGG}\left(\left\{m_{i j}^{l} \mid v_{j} \in \mathcal{N}_{v_{i}}\right\}\right)
Mil?=AGG({mijl?∣vj?∈Nvi??}). 其中
N
v
i
\mathcal{N}_{v_i}
Nvi??是节点
v
i
v_i
vi?的邻居的节点,它的定义不同会产生不同的GNN变种。(3)GNN 使用聚合函数
M
i
l
M_i^l
Mil?聚合节点
v
i
v_i
vi?的representation
h
i
l
?
1
\mathbf{h}_i^{l-1}
hil?1?, 然后进行非线性转换获得节点
v
i
v_i
vi?的节点在
l
l
l层表示
h
i
l
\mathbf{h}_i^l
hil?:
h
i
l
=
UPDATE
?
(
M
i
l
,
h
i
l
?
1
)
\mathbf{h}_{i}^{l}=\operatorname{UPDATE}\left(M_{i}^{l}, \mathbf{h}_{i}^{l-1}\right)
hil?=UPDATE(Mil?,hil?1?), 然后经过
L
L
L层获得最后的输出:
z
i
=
h
i
L
\mathbf{z}_{i}=\mathbf{h}_{i}^{L}
zi?=hiL?。
1.2 GNNEXPLAINER: Problem formulation
我们处理问题的关键是节点
v
v
v的计算,将节点邻居的信息进行汇总,产生节点
v
v
v的预测
y
^
\hat{y}
y^?。节点
v
v
v的最终输出为
z
\mathbf{z}
z. 图
G
c
(
v
)
G_c(v)
Gc?(v)的计算与临接矩阵
A
c
(
v
)
∈
{
0
,
1
}
n
×
n
A_{c}(v) \in\{0,1\}^{n \times n}
Ac?(v)∈{0,1}n×n和节点特征
X
c
(
v
)
=
{
x
j
∣
v
j
∈
G
c
(
v
)
}
X_{c}(v)=\left\{x_{j} \mid v_{j} \in G_{c}(v)\right\}
Xc?(v)={xj?∣vj?∈Gc?(v)}有关。GNN模型
Φ
\Phi
Φ学习
Y
Y
Y的概率分布
P
Φ
(
Y
∣
G
c
,
X
c
)
P_{\Phi}\left(Y \mid G_{c}, X_{c}\right)
PΦ?(Y∣Gc?,Xc?), 其中
Y
Y
Y代表标签
1
,
?
?
,
C
{1,\cdots,C}
1,?,C随机变量,即每个节点属于
C
C
C类中每个类别的概率。
GNN的预测
y
^
=
Φ
(
G
c
(
v
)
,
X
c
(
v
)
)
\hat{y}=\Phi\left(G_{c}(v), X_{c}(v)\right)
y^?=Φ(Gc?(v),Xc?(v)),模型
Φ
\Phi
Φ主要是由图的结构信息
G
c
(
v
)
G_c(v)
Gc?(v)和节点的特征
X
c
(
v
)
X_c(v)
Xc?(v)决定的。一般地, GNNEXPLAINER将预测值
y
^
\hat{y}
y^? 解释为
(
G
S
,
X
S
F
)
\left(G_{S}, X_{S}^{F}\right)
(GS?,XSF?), 其中
G
S
G_S
GS?是预测图的子图,
X
S
X_S
XS?是
G
S
G_S
GS?的节点特征,
X
S
F
X_S^F
XSF?是
G
S
G_S
GS?中节点的子集(通过
F
F
F进行mask,
X
S
F
=
{
x
j
F
∣
v
j
∈
G
S
}
X_{S}^{F}=\{x_{j}^{F} \mid v_{j} \in G_S\}
XSF?={xjF?∣vj?∈GS?} )。
2 GNNEXPLAINER
接下来,我们介绍一下 GNNEXPLAINER 如何在单条(2.1, 2.2)和多条(2.3)上的预测进行模型解释。最后介绍GNNEXPLAINER在机器学习任务上的应用(2.4),如链路预测和图分类。
2.1 Single-instance explanations
给定一个节点
v
v
v, 我们的目标是识别子图
G
S
?
G
c
G_{S} \subseteq G_{c}
GS??Gc?和相关特征
X
S
=
{
x
j
∣
v
j
∈
G
S
}
X_S=\left\{x_{j} \mid v_{j} \in G_{S}\right\}
XS?={xj?∣vj?∈GS?}, 这些对于GNN预测
y
^
\hat{y}
y^? 是非常重要的。现在,我们假设
X
S
X_S
XS?是子集节点的特征,
d
d
d维。在2.2将要讨论哪一维特征能够对模型进行解释。使用互信息
M
I
MI
MI衡量重要性, GNNEXPLAINER优化框架如下:
max
?
G
S
M
I
(
Y
,
(
G
S
,
X
S
)
)
=
H
(
Y
)
?
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
(1)
\max _{G_{S}} M I\left(Y,\left(G_{S}, X_{S}\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{1}
GS?max?MI(Y,(GS?,XS?))=H(Y)?H(Y∣G=GS?,X=XS?)(1) 对于节点
v
v
v,
M
I
MI
MI是衡量是当计算图被限制在子图
G
S
G_S
GS?,节点特征被限制在
X
S
X_S
XS?时,预测概率
y
^
=
Φ
(
G
c
,
X
c
)
\hat{y}=\Phi\left(G_{c}, X_{c}\right)
y^?=Φ(Gc?,Xc?)的变化。
举例来说,
v
j
∈
G
c
(
v
i
)
,
v
j
≠
v
i
v_{j} \in G_{c}\left(v_{i}\right), v_{j} \neq v_{i}
vj?∈Gc?(vi?),vj??=vi?,如果移除
v
j
v_j
vj?,
y
^
i
\hat{y}_i
y^?i?的概率显著下降, 则节点
v
j
v_j
vj?就是很好反事实解释。类似地,
(
v
j
,
v
k
)
∈
G
c
(
v
i
)
,
v
j
,
v
k
≠
v
i
\left(v_{j}, v_{k}\right) \in G_{c}\left(v_{i}\right), v_{j}, v_{k} \neq v_{i}
(vj?,vk?)∈Gc?(vi?),vj?,vk??=vi?,如果移除
v
j
v_j
vj?和
v
k
v_k
vk?之间的边,
y
^
i
\hat{y}_i
y^?i? 的预测概率值显著下降,则
v
j
v_j
vj?和
v
k
v_k
vk?之间的边是很好的反事实解释。
在Eq.(1)中, 交叉项
H
(
Y
)
H(Y)
H(Y)是常数,因为模型
Φ
\Phi
Φ已经训练好,因此,最大化
Y
Y
Y和
(
G
S
,
X
S
)
(G_S,X_S)
(GS?,XS?)之间的互信息等于最小化条件熵
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
H\left(Y \mid G=G_{S}, X=X_{S}\right)
H(Y∣G=GS?,X=XS?),如下:
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
=
?
E
Y
∣
G
S
,
X
S
[
log
?
P
Φ
(
Y
∣
G
=
G
S
,
X
=
X
S
)
]
(2)
H\left(Y \mid G=G_{S}, X=X_{S}\right)=-\mathbb{E}_{Y \mid G_{S}, X_{S}}\left[\log P_{\Phi}\left(Y \mid G=G_{S}, X=X_{S}\right)\right]\tag{2}
H(Y∣G=GS?,X=XS?)=?EY∣GS?,XS??[logPΦ?(Y∣G=GS?,X=XS?)](2) 以子图
G
S
G_S
GS?对
y
^
\hat{y}
y^? 进行解释, 实际上最小化
Φ
\Phi
Φ的不确定性。实际上,最大化概率
y
^
\hat{y}
y^?。为了给出简介的解释,我们给
G
S
G_S
GS?增加限制:
∣
G
S
∣
≤
K
M
\left|G_{S}\right| \leq K_{M}
∣GS?∣≤KM?, 其中
G
S
G_S
GS?最多有
K
M
K_M
KM?个节点。这意味着, GNNEXPLAINER通过
K
M
K_M
KM?边消除
G
C
G_C
GC?的噪声,给出预测的最大互信息。
**GNNEXPLAINER’s optimization framework.**对于
G
c
G_c
Gc?来说,用于解释
y
^
\hat{y}
y^? 的子图
G
S
G_S
GS?非常多,直接处理是非常困难的。我们考虑部分邻接矩阵的方式:
A
S
[
j
,
k
]
≤
A
c
[
j
,
k
]
A_{S}[j, k] \leq A_{c}[j, k]
AS?[j,k]≤Ac?[j,k],其中,
A
S
∈
[
0
,
1
]
n
×
n
A_{S} \in[0,1]^{n \times n}
AS?∈[0,1]n×n, 对于所有
j
,
k
j,k
j,k增加以上限制。这个近似可以理解为子图是
G
c
G_c
Gc?的近似。我们将
G
S
~
G
G_{S} \sim \mathcal{G}
GS?~G看做图的随机变量,目标函数Eq.(2)可以变换为:
min
?
G
E
G
S
~
G
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
(3)
\min _{\mathcal{G}} \mathbb{E}_{G_{S} \sim \mathcal{G}} H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{3}
Gmin?EGS?~G?H(Y∣G=GS?,X=XS?)(3) 由于凸的假设,使用Jensen不等式给出上限:
min
?
G
H
(
Y
∣
G
=
E
G
[
G
S
]
,
X
=
X
S
)
(4)
\min _{\mathcal{G}} H\left(Y \mid G=\mathbb{E}_{\mathcal{G}}\left[G_{S}\right], X=X_{S}\right)\tag{4}
Gmin?H(Y∣G=EG?[GS?],X=XS?)(4) 在实际中,由于神经网络的复杂性,凸的假设是不成立的,但是,最小化这个目标函数和正则项通常会带来比较的解释。
为了估计
E
G
\mathbb{E}_{\mathcal{G}}
EG?, 我们将其分解为multivariate Bernoulli distribution:
P
G
(
G
S
)
=
∏
(
j
,
k
)
∈
G
c
A
S
[
j
,
k
]
P_{\mathcal{G}}\left(G_{S}\right)=\prod_{(j, k) \in G_{c}} A_{S}[j, k]
PG?(GS?)=∏(j,k)∈Gc??AS?[j,k],其中
A
S
A_S
AS?的
(
j
,
k
)
-th
(j,k)\text{-th}
(j,k)-th条目代表边
(
v
j
,
v
k
)
(v_j,v_k)
(vj?,vk?)之间是否有边存在。我们经验发现,使用正则项可以使得分解值收敛局部最小,即使GNN是非凸的。将Equation 4中
E
G
[
G
S
]
\mathbb{E}_G[G_S]
EG?[GS?]替换为masking 邻接矩阵
A
c
⊙
σ
(
M
)
A_{c} \odot \sigma(M)
Ac?⊙σ(M)进行优化 ,
M
∈
R
n
×
n
M \in \mathbb{R}^{n \times n}
M∈Rn×n指的是Mask,
⊙
\odot
⊙指element-wise乘积,
σ
\sigma
σ 指的是将mask映射为
[
0
,
1
]
n
×
n
[0,1]^{n \times n}
[0,1]n×n.
在一些应用中,用户更关注如何将训练的模型用于预测想要的label。我们需要修改Equation4:
min
?
M
?
∑
c
=
1
C
1
[
y
=
c
]
log
?
P
Φ
(
Y
=
y
∣
G
=
A
c
⊙
σ
(
M
)
,
X
=
X
c
)
(5)
\min _{M}-\sum_{c=1}^{C} \mathbb{1}[y=c] \log P_{\Phi}\left(Y=y \mid G=A_{c} \odot \sigma(M), X=X_{c}\right)\tag{5}
Mmin??c=1∑C?1[y=c]logPΦ?(Y=y∣G=Ac?⊙σ(M),X=Xc?)(5) 该公式 使用Mask机制,将
σ
(
M
)
\sigma(M)
σ(M)和
A
c
A_c
Ac?进行乘积,移除
M
M
M中小的值,以达到用子图
G
S
G_S
GS?解释GNN对节点
v
v
v的预测值
y
^
\hat{y}
y^?进行解释的作用。
2.2 Joint learning of graph structural and node feature information
为了识别节点特征对预测值
y
^
\hat{y}
y^?的重要性, GNNEXPLAINER学习
G
S
G_S
GS?节点特征
F
F
F选择器。与节点所有特征不同,
X
S
=
{
x
j
∣
v
j
∈
G
S
}
X_{S}=\left\{x_{j} \mid v_{j} \in G_{S}\right\}
XS?={xj?∣vj?∈GS?}, GNNEXPLAINER考虑
G
S
G_S
GS?的子集特征
X
S
F
X_{S}^{F}
XSF?, 特征的选择通过二值特征选择器
F
∈
{
0
,
1
}
d
F \in\{0,1\}^{d}
F∈{0,1}d(Figure 2B):
X
S
F
=
{
x
j
F
∣
v
j
∈
G
S
}
,
x
j
F
=
[
x
j
,
t
1
,
…
,
x
j
,
t
k
]
?for?
F
t
i
=
1
(6)
X_{S}^{F}=\left\{x_{j}^{F} \mid v_{j} \in G_{S}\right\}, \quad x_{j}^{F}=\left[x_{j, t_{1}}, \ldots, x_{j, t_{k}}\right] \text { for } F_{t_{i}}=1\tag{6}
XSF?={xjF?∣vj?∈GS?},xjF?=[xj,t1??,…,xj,tk??]?for?Fti??=1(6) 其中,
x
j
F
x_j^F
xjF?是没有被
F
F
F mask out的节点特征。
(
G
S
,
X
S
)
(G_S,X_S)
(GS?,XS?)进行联合优化以最大化互信息:
max
?
G
S
,
F
M
I
(
Y
,
(
G
S
,
F
)
)
=
H
(
Y
)
?
H
(
Y
∣
G
=
G
S
,
X
=
X
S
F
)
(7)
\max _{G_{S}, F} M I\left(Y,\left(G_{S}, F\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}^{F}\right)\tag{7}
GS?,Fmax?MI(Y,(GS?,F))=H(Y)?H(Y∣G=GS?,X=XSF?)(7) 该方程对Eq.(1)目标函数进行调整,同时考虑结构和节点特征两个方面,对预测
y
^
\hat{y}
y^?进行解释。
Learning binary feature selector
F
F
F. 我们设
X
S
=
X
S
⊙
F
X_S=X_S\odot F
XS?=XS?⊙F, 其中
F
F
F是需要学习的参数。如果某个特征不重要,GNN会使得它的权重为0. 实际上,若果这个特征不重要,移除这个特征预测值不会有太大的变化,如果这个特征重要,预测值会显著下降。但是这种方法会忽略一些特征很重要,但是取值接近0。为了解决这个问题,在训练的过程中,我们使用蒙特卡洛从节点
X
S
X_S
XS?的边缘经验分布抽样。然后,我们使用参数化技巧进行反向传播,学习feature mask
F
F
F。特别地,随机变量
X
X
X计算如下:
X
=
Z
+
(
X
S
?
Z
)
⊙
F
X=Z+\left(X_{S}-Z\right) \odot F
X=Z+(XS??Z)⊙F s.t.
∑
j
F
j
≤
K
F
\sum_{j} F_{j} \leq K_{F}
∑j?Fj?≤KF?,其中
Z
Z
Z是从经验分布抽样的
d
d
d维随机变量,
K
F
K_F
KF?是保留的最大特征的数量,是可学习的参数。
Integrating additional constraints into explanations. 为了强化可解释性,我们可以对Eq.(7)增加正则项。例如,为了使得structural and node feature masks to be discrete, 我们使用element-wise entropy,或者增加特定领域限制,如,拉格朗日正则项。我们也可以将mask的元素求和,作为正则项。
最后,需要注意的是对GNN进行解释必须是一个有效的计算图。因为解释
(
G
S
,
X
S
)
\left(G_{S}, X_{S}\right)
(GS?,XS?)必须允许GNN的message能够流向节点
v
v
v, 以此来预测
y
^
\hat{y}
y^?. 重要的是, GNNEXPLAINER 自动可以提供有效计算图,因为它会在整个图上优化structural mask。如果边是没有连接的,它不会被选择,不会影响最终GNN预测。
2.3 Multi-instance explanations through graph prototypes
我们的目标是分析子图如何对一类标签进行解释, GNNEXPLAINER能够基于 graph alignments and prototypes对多实例进行解释。
首先,我们先选择一个类别
c
c
c的参考样本样本点,例如,将其他节点embedding的均值赋值
c
c
c。我们利用
G
S
(
v
c
)
G_S(v_c)
GS?(vc?)对
v
c
v_c
vc?进行解释,然后将解释赋值给这个类别
c
c
c的其他节点。如果在大图中,进行匹配是非常具有挑战性的。但是单条样本产生是一个小图,而且near-optimal pairwise graph matchings是非常高效的。
其次,我们将邻接矩阵进行汇总给a graph prototype
A
proto
A_{\text{proto}}
Aproto?, 例如计算中位数.
A
proto
A_{\text{proto}}
Aproto?用于识别graph patterns,它在同类别中是共享的。可以用于预测和模型解释。
2.4 GNNEXPLAINER model extensions
Any machine learning task on graphs. 除了能够解释节点分类,在不需要修改优化算法的情况下,GNNEXPLAINER可以用于链路预测和图分类。当对
(
v
j
,
v
k
)
(v_j,v_k)
(vj?,vk?)进行链路预测时,GNNEXPLAINER会学习
X
S
(
v
j
)
X_S(v_j)
XS?(vj?)和
X
S
(
v
k
)
X_S(v_k)
XS?(vk?)两个mask。当进行图分类时,会将我们想解释图的所有邻接矩阵进行union.
Any GNN model. 现在GNN主要基于 message passing构建各种结构, GNNEXPLAINER能够对它们进行解释。
Computational complexity. GNNEXPLAINER的优化取决于计算图
G
c
G_c
Gc?的大小,
G
c
(
v
)
G_c(v)
Gc?(v)的邻接矩阵
A
c
(
v
)
A_c(v)
Ac?(v)等于mask
M
M
M的大小,需要GNNEXPLAINER学习。但是,通常来说,计算图相对较小, 即使输入大图 ,GNNEXPLAINER也能对其进行有效的解释。
|