一、GCN的原理
简单,也有很多博客在说明! 链接1:https://arxiv.org/abs/1609.02907 链接2:https://mp.weixin.qq.com/s/DJAimuhrXIXjAqm2dciTXg
二、GCN的层代码
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
class GraphConvolution(Module):
def __init__(self, in_features, out_features, bias=True):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = torch.spmm(adj, support)
if self.bias is not None:
return output + self.bias
else:
return output
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
解释说明:
-
class GraphConvolution(Module) :继承Module 类。 -
class GraphConvolution(Module) 中有两个恒常在的函数:__init__() 用于初始化参数或者模块等;forward() 函数属于输入变量并做运算。 -
def __init__(self, in_features, out_features, bias=True) 这个函数中:
super(GraphConvolution, self).__init__() :是按照 GraphConvolution的父类Module的初始化方式进行初始化。self.in_features = in_features :用来定义初始化变量,可以在整个class的任意一个函数内部使用。self.weight = Parameter(torch.FloatTensor(in_features, out_features)) :定义新的初始化变量。模型中的参数,它是Parameter() 类,也是定义GCN的核心操作之一。
-
forward(self, input, adj) 函数中输入变量input,adj.
support = torch.mm(input, self.weight) 是矩阵乘法,input * self.weight.注意到torch.mm使用范围仅限于二维矩阵。当存在batch变量的时候,也就是infut.shape=[B, N, F]三维形状的时候不使用。建议改为torch.matmul .output = torch.spmm(adj, support) 也是矩阵乘法。adj是我们的矩阵输入变量,具有N*N个元素,通常情况下采用稀疏矩阵来保存。spmm是稀疏矩阵的乘法: 支持 sparse 在前,dense 在后的矩阵乘法, 两个sparse相乘或者dense在前的乘法不支持, 当然两个dense矩阵相乘是支持的. mm是二维矩阵的乘法,不适合用于三维矩阵。
-
reset_parameters(self) 是参数初始化
self.weight.size(1) 是weight的形状(in_features, out_features)中的out_featuresmath.sqrt(4) =2.0是返回平方根self.weight.data.uniform_(-stdv, stdv) :是指weight.data按照均匀分布,上限为-stdv,下限位stdv.- 此外对weight的数据初始化方法还有另外一种:
init.kaiming_uniform_(self.weight)
-
__repr__(self) 返回该clas的一些介绍。比如
三、GCN的搭建
import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
class GCN(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout):
super(GCN, self).__init__()
self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nclass)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj)
return F.log_softmax(x, dim=1)
|