IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【AM-GCN】代码解读之主程序 -> 正文阅读

[人工智能]【AM-GCN】代码解读之主程序

[前篇]:
初了解(一)


一、导入库

import torch.nn.functional as F #常用函数
import torch.optim as optim     #优化算法
from utils import *
from models import SFGCN        #主模型
from sklearn.metrics import f1_score #计算F1分数,也称为平衡F分数或F测度
import os
import argparse
from config import Config
import torch
import numpy as np

解释说明

  1. torch.optim
  2. f1_score

二、参数读取和设置

if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    parse = argparse.ArgumentParser()
    parse.add_argument("-d", "--dataset", help="dataset", type=str, required=True)
    parse.add_argument("-l", "--labelrate", help="labeled data for train per class", type = int, required = True)
    args = parse.parse_args()
    config_file = "./config/" + str(args.labelrate) + str(args.dataset) + ".ini"
    config = Config(config_file)

    cuda = not config.no_cuda and torch.cuda.is_available() #cuda可否能用

    use_seed = not config.no_seed # cuda
    if use_seed:
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        if cuda:
            torch.cuda.manual_seed(config.seed)

解释说明

  1. os.environ[“CUDA_VISIBLE_DEVICES”] = "2"设置使用的标号为"2"的显卡。os.environ[‘环境变量名称’]=‘环境变量值’ :其中key和value均为string类型
  2. argparse 参数配置的方法,可以打开cmd来设置参数的值,具体使用方法见5中链接。

    1)import argparse 首先导入模块
    ?2)parser = argparse.ArgumentParser() 创建一个解析器
    3)parser.add_argument() 向该解析器中添加你要关注的命令行参数和选项
    4)parser.parse_args() 进行解析

  3. config = Config(config_file) 是调用的config.py中的class类函数并实例化。其可以读取 'config_file '文件路径中的参数文件并配置参数。
  4. 'config_file '文件路径,涉及了 标签率args.labelrate 数据集args.dataset.通过原文可以知道是

    三个标签率(即每类20、40、60个标签节点)
    六个数据集(见《初了解一》)

  5. 区别:import argparse 和 import configparser,↙?具体见链接。
    1. 可以理解为都是参数配置的模块,两者并非不可互相替代。
    2. 显著区别是使用方法不同。通俗来讲,
      configparser更像是把参数进行分类归纳整理在一个.ini的文件中,通过读取文件的方式获得文件中的各项参数。
      argparse 更像是机器的开关,在开机运作的时候设置各项功能。

6.参数来源以及对照值

对象读取命令读取文件参数值
config.no_cudaconf.getboolean("Model_Setup", "no_cuda")20acmFalse
config.no_seedgetboolean("Model_Setup", "no_seed")20acmFalse
config.seedgetint("Model_Setup", "seed")20acm123
  1. 函数解析(具体见超链接):

np.random.seed()函数:用于生成指定随机数。
torch.manual_seed()函数:CPU生成随机数的种子,方便下次复现实验结果
torch.cuda.manual_seed()函数:固定生成随机数的种子,使得每次运行该 .py 文件时生成的随机数相同

  1. cuda–至结束。含义为是否使用cuda以及有可用的cuda,如果使用cuda则设置随机数种子123,否则使用cpu设置随机数种子123.

二、数据读取

    sadj, fadj = load_graph(args.labelrate, config)
    features, labels, idx_train, idx_test = load_data(config)
  • fadj是特征图矩阵 【3025,3025】–> 在acm中有7282个非零值
  • sadj是结构图矩阵 【3025,3025】–>有26256个非零值
  • features是特征向量—>(3025, 1870)
  • labels是标签向量 —>(3025,1)
  • idx_train是训练数据索引 -->(1000,)
  • idx_test是测试数据索引 -->(60,)
    以上数据结构以acm数据为例。且都是torch中的结构

三、模型准备

  1. 模型实例化
  2. 如果cuda可用,将数据放置到cuda上。
   model = SFGCN(nfeat = config.fdim,
              nhid1 = config.nhid1,
              nhid2 = config.nhid2,
              nclass = config.class_num,
              n = config.n,
              dropout = config.dropout)
    if cuda:
        model.cuda()
        features = features.cuda()
        sadj = sadj.cuda()
        fadj = fadj.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_test = idx_test.cuda()
    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

四、模型训练

    acc_max = 0
    f1_max = 0
    epoch_max = 0
    for epoch in range(config.epochs):
        loss, acc_test, macro_f1, emb = train(model, epoch)
        if acc_test >= acc_max:
            acc_max = acc_test
            f1_max = macro_f1
            epoch_max = epoch
    print('epoch:{}'.format(epoch_max),
          'acc_max: {:.4f}'.format(acc_max),
          'f1_max: {:.4f}'.format(f1_max))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章           查看所有文章
加:2022-04-09 18:22:35  更:2022-04-09 18:27:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:11:05-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码