IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【torch】搭建GCN的详细介绍 -> 正文阅读

[人工智能]【torch】搭建GCN的详细介绍

一、GCN的原理

简单,也有很多博客在说明!
链接1:https://arxiv.org/abs/1609.02907
链接2:https://mp.weixin.qq.com/s/DJAimuhrXIXjAqm2dciTXg

二、GCN的层代码

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
class GraphConvolution(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

解释说明:

  1. class GraphConvolution(Module):继承Module类。

  2. class GraphConvolution(Module)中有两个恒常在的函数:__init__()用于初始化参数或者模块等;forward()函数属于输入变量并做运算。

  3. def __init__(self, in_features, out_features, bias=True)这个函数中:

    • super(GraphConvolution, self).__init__():是按照 GraphConvolution的父类Module的初始化方式进行初始化。
    • self.in_features = in_features:用来定义初始化变量,可以在整个class的任意一个函数内部使用。
    • self.weight = Parameter(torch.FloatTensor(in_features, out_features)):定义新的初始化变量。模型中的参数,它是Parameter()类,也是定义GCN的核心操作之一。
      在这里插入图片描述
  4. forward(self, input, adj)函数中输入变量input,adj.

    • support = torch.mm(input, self.weight) 是矩阵乘法,input * self.weight.注意到torch.mm使用范围仅限于二维矩阵。当存在batch变量的时候,也就是infut.shape=[B, N, F]三维形状的时候不使用。建议改为torch.matmul.
    • output = torch.spmm(adj, support)也是矩阵乘法。adj是我们的矩阵输入变量,具有N*N个元素,通常情况下采用稀疏矩阵来保存。spmm是稀疏矩阵的乘法:
      支持 sparse 在前,dense 在后的矩阵乘法
      两个sparse相乘或者dense在前的乘法不支持,
      当然两个dense矩阵相乘是支持的.
      mm是二维矩阵的乘法,不适合用于三维矩阵。
  5. reset_parameters(self)是参数初始化

    • self.weight.size(1)是weight的形状(in_features, out_features)中的out_features
    • math.sqrt(4)=2.0是返回平方根
    • self.weight.data.uniform_(-stdv, stdv):是指weight.data按照均匀分布,上限为-stdv,下限位stdv.
    • 此外对weight的数据初始化方法还有另外一种:init.kaiming_uniform_(self.weight)
  6. __repr__(self)返回该clas的一些介绍。比如
    在这里插入图片描述

三、GCN的搭建

import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout
    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-24 00:32:32  更:2022-03-24 00:33:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 1:10:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码