IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Task07:超大规模数据集类的创建&图预测任务实践 -> 正文阅读

[人工智能]Task07:超大规模数据集类的创建&图预测任务实践

Task07:超大规模数据集类的创建&图预测任务实践

本文参考datawhale开源学习资料

一、超大规模数据集类的创建

1. Dataset基类简介

  • InMemoryDataset:会一次性把数据全部加载到内存中。
  • Dataset: 每次加载一个数据到内存中,在数据集比较大的时候使用。

torch_geometric.data.Dataset需要实现另外的两个方法

  • len():返回数据集中的样本的数量。
  • get():实现加载单个图的操作。注意:在内部,__getitem__()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。

下面让我们通过一个简化的例子看继承torch_geometric.data.Dataset基类的规范

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url

class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

其中,每个Data对象在process()方法中单独被保存,并在get()中通过指定索引进行加载。

FAQ

  • Q:如何跳过download/process?

  • A:对于无需下载数据集原文件的情况,我们不重写(override)download方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process方法即可跳过预处理。

  • Q:必须使用Dataset类吗?

  • A:通过下面的方式,我们可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:

from torch_geometric.data import Data, DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

torch_geometric.data.DataLoader继承自torch.utils.data.DataLoader,通过torch_geometric.data.DataLoader可以方便地使用 mini-batch。

我们也可以通过下面的方式将一个列表的Data对象组成一个Batch

from torch_geometric.data import Data, Batch

data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)
  • torch_geometric.data.Batch继承自torch_geometric.data.Data,并且多了一个属性:batchbatch是一个列向量,它将每个元素映射到每个 mini-batch 中的相应图:batch = [ 0 ? 0 1 ? n ? 2 n ? 1 ? n ? 1 ] ? =\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top} =[0???0?1???n?2?n?1???n?1?]?。这种记录对于图输出阶段是非常有必要的,相同索引的节点属于同一个图,如何经过torch_scatter进行节点表示的summeanmax等操作。
  • Batch还有个方法to_data_list,即通过Batch重构出一个包含Data的列表,但是该方法的使用前提是Batch是由from_data_list创建而成的。

2. 图样本封装成批(BATCHING)与DataLoader

合并小图组成大图

图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
KaTeX parse error: No such environment: split at position 8: \begin{?s?p?l?i?t?}?\mathbf{A} = \b…
此方法有以下关键的优势

  • 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
  • 没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。

对于邻接矩阵,PyTorch Geometric实现了稀疏(sparse)和稠密(dense)的方式,SparseTensor能更高效地实现消息传递范式

通过torch_geometric.data.DataLoader类,多个小图被封装成一个大图。torch_geometric.data.DataLoader是PyTorch的DataLoader的子类,它覆盖了collate()函数,改函数定义了一列表的样本是如何封装成批的。因此,所有可以传递给PyTorch DataLoader的参数也可以传递给PyTorch Geometric的 DataLoader,例如,num_workers

小图的属性增值与拼接

将小图存储到大图中时需要对小图的属性做一些修改,一个最显著的例子就是要对节点序号增值。在最一般的形式中,PyTorch Geometric的DataLoader类会自动对edge_index张量增值,增加的值为当前被处理图的前面的图的累积节点数量。比方说,现在对第 k k k个图的edge_index张量做增值,前面 k ? 1 k-1 k?1个图的累积节点数量为 n n n,那么对第 k k k个图的edge_index张量的增值 n n n。增值后,对所有图的edge_index张量(其形状为[2, num_edges])在第二维中连接起来。

然而,有一些特殊的场景中(如下所述),基于需求我们希望能修改这一行为。PyTorch Geometric允许我们通过覆盖torch_geometric.data.__inc__()torch_geometric.data.__cat_dim__()函数来实现我们希望的行为。在未做修改的情况下,它们在Data类中的定义如下。

def __inc__(self, key, value):
    if 'index' in key or 'face' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value):
    if 'index' in key or 'face' in key:
        return 1
    else:
        return 0

我们可以看到,__inc__()定义了两个连续的图的属性之间的增量大小,而__cat_dim__()定义了同一属性的图形张量应该在哪个维度上被连接起来。PyTorch Geometric为存储在Data类中的每个属性调用此二函数,并以它们各自的key和值item作为参数。相关例子可以阅读ADVANCED MINI-BATCHING或者datawhale团队对此教程的翻译

二、创建超大规模数据集类实践

PCQM4M-LSC是一个分子图的量子特性回归数据集,它包含了3,803,453个图。

值得关注的是,该数据集是KDD Cup 2021 OGB-LSC图预测赛道赛道使用的数据集,夺冠队伍使用了可应用于图结构数据的 Graphormer 模型以微弱的优势夺冠,相关报道可以阅读Transformer杀疯了!竟在图神经网络的ImageNet大赛中夺冠,力压DeepMind、百度…,提出Graphormer的论文是Do Transformers Really Perform Bad for Graph Representation?

通过pip install ogb命令可安装ogb包。ogb文档可见于Get Started | Open Graph Benchmark (stanford.edu)

我们定义的数据集类如下:

import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil

RDLogger.DisableLog('rdApp.*')

class MyPCQM4MDataset(Dataset):

    def __init__(self, root):
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super(MyPCQM4MDataset, self).__init__(root)

        filepath = osp.join(root, 'raw/data.csv.gz')
        data_df = pd.read_csv(filepath)
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self):
        return 'data.csv.gz'

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        return len(self.smiles_list)

    def get(self, idx):
        smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
        graph = smiles2graph(smiles)
        assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert(len(graph['node_feat']) == graph['num_nodes'])

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.Tensor([homolumogap])
        num_nodes = int(graph['num_nodes'])
        data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
        return data

    # 获取数据集划分
    def get_idx_split(self):
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
        return split_dict

if __name__ == "__main__":
    dataset = MyPCQM4MDataset('dataset2')
    from torch_geometric.data import DataLoader
    from tqdm import tqdm
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
    for batch in tqdm(dataloader):
        pass

在生成一个该数据集类的对象时,程序首先会检查指定的文件夹下是否存在data.csv.gz文件,如果不在,则会执行download方法,这一过程是在运行super类的__init__方法中发生的。然后程序继续执行__init__方法的剩余部分,读取data.csv.gz文件,获取存储图信息的smiles格式的字符串,以及回归预测的目标homolumogap。我们将由smiles格式的字符串转成图的过程在get()方法中实现,这样我们在生成一个DataLoader变量时,通过指定num_workers可以实现并行执行生成多个图。

三、图预测任务实践

相关代码见gin_regression

在这里插入图片描述
实验室的服务器跑的非常慢,后面会继续跟进。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-10 14:32:45  更:2021-07-10 14:33:25 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 10:12:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码