Task07:超大规模数据集类的创建&图预测任务实践
本文参考datawhale开源学习资料
一、超大规模数据集类的创建
1. Dataset 基类简介
torch_geometric.data.Dataset 需要实现另外的两个方法:
len() :返回数据集中的样本的数量。get() :实现加载单个图的操作。注意:在内部,__getitem__() 返回通过调用get() 来获取Data 对象,并根据transform 参数对它们进行选择性转换。
下面让我们通过一个简化的例子看继承torch_geometric.data.Dataset 基类的规范:
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
path = download_url(url, self.raw_dir)
...
def process(self):
i = 0
for raw_path in self.raw_paths:
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
其中,每个Data 对象在process() 方法中单独被保存,并在get() 中通过指定索引进行加载。
FAQ
from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
torch_geometric.data.DataLoader 继承自torch.utils.data.DataLoader ,通过torch_geometric.data.DataLoader 可以方便地使用 mini-batch。
我们也可以通过下面的方式将一个列表的Data 对象组成一个Batch :
from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)
torch_geometric.data.Batch 继承自torch_geometric.data.Data ,并且多了一个属性:batch 。batch 是一个列向量,它将每个元素映射到每个 mini-batch 中的相应图:batch
=
[
0
?
0
1
?
n
?
2
n
?
1
?
n
?
1
]
?
=\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top}
=[0???0?1???n?2?n?1???n?1?]?。这种记录对于图输出阶段是非常有必要的,相同索引的节点属于同一个图,如何经过torch_scatter 进行节点表示的sum,mean或max等操作。Batch 还有个方法to_data_list ,即通过Batch 重构出一个包含Data 的列表,但是该方法的使用前提是Batch 是由from_data_list 创建而成的。
2. 图样本封装成批(BATCHING)与DataLoader 类
合并小图组成大图
图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为: KaTeX parse error: No such environment: split at position 8: \begin{?s?p?l?i?t?}?\mathbf{A} = \b… 此方法有以下关键的优势:
- 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
- 没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。
对于邻接矩阵,PyTorch Geometric实现了稀疏(sparse)和稠密(dense)的方式,SparseTensor 能更高效地实现消息传递范式。
通过torch_geometric.data.DataLoader 类,多个小图被封装成一个大图。torch_geometric.data.DataLoader 是PyTorch的DataLoader 的子类,它覆盖了collate() 函数,改函数定义了一列表的样本是如何封装成批的。因此,所有可以传递给PyTorch DataLoader 的参数也可以传递给PyTorch Geometric的 DataLoader ,例如,num_workers 。
小图的属性增值与拼接
将小图存储到大图中时需要对小图的属性做一些修改,一个最显著的例子就是要对节点序号增值。在最一般的形式中,PyTorch Geometric的DataLoader 类会自动对edge_index 张量增值,增加的值为当前被处理图的前面的图的累积节点数量。比方说,现在对第
k
k
k个图的edge_index 张量做增值,前面
k
?
1
k-1
k?1个图的累积节点数量为
n
n
n,那么对第
k
k
k个图的edge_index 张量的增值
n
n
n。增值后,对所有图的edge_index 张量(其形状为[2, num_edges] )在第二维中连接起来。
然而,有一些特殊的场景中(如下所述),基于需求我们希望能修改这一行为。PyTorch Geometric允许我们通过覆盖torch_geometric.data.__inc__() 和torch_geometric.data.__cat_dim__() 函数来实现我们希望的行为。在未做修改的情况下,它们在Data 类中的定义如下。
def __inc__(self, key, value):
if 'index' in key or 'face' in key:
return self.num_nodes
else:
return 0
def __cat_dim__(self, key, value):
if 'index' in key or 'face' in key:
return 1
else:
return 0
我们可以看到,__inc__() 定义了两个连续的图的属性之间的增量大小,而__cat_dim__() 定义了同一属性的图形张量应该在哪个维度上被连接起来。PyTorch Geometric为存储在Data 类中的每个属性调用此二函数,并以它们各自的key 和值item 作为参数。相关例子可以阅读ADVANCED MINI-BATCHING或者datawhale团队对此教程的翻译。
二、创建超大规模数据集类实践
PCQM4M-LSC是一个分子图的量子特性回归数据集,它包含了3,803,453个图。
值得关注的是,该数据集是KDD Cup 2021 OGB-LSC图预测赛道赛道使用的数据集,夺冠队伍使用了可应用于图结构数据的 Graphormer 模型以微弱的优势夺冠,相关报道可以阅读Transformer杀疯了!竟在图神经网络的ImageNet大赛中夺冠,力压DeepMind、百度…,提出Graphormer的论文是Do Transformers Really Perform Bad for Graph Representation?。
通过pip install ogb 命令可安装ogb 包。ogb 文档可见于Get Started | Open Graph Benchmark (stanford.edu)。
我们定义的数据集类如下:
import os
import os.path as osp
import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil
RDLogger.DisableLog('rdApp.*')
class MyPCQM4MDataset(Dataset):
def __init__(self, root):
self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
super(MyPCQM4MDataset, self).__init__(root)
filepath = osp.join(root, 'raw/data.csv.gz')
data_df = pd.read_csv(filepath)
self.smiles_list = data_df['smiles']
self.homolumogap_list = data_df['homolumogap']
@property
def raw_file_names(self):
return 'data.csv.gz'
def download(self):
path = download_url(self.url, self.root)
extract_zip(path, self.root)
os.unlink(path)
shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))
def len(self):
return len(self.smiles_list)
def get(self, idx):
smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
graph = smiles2graph(smiles)
assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
assert(len(graph['node_feat']) == graph['num_nodes'])
x = torch.from_numpy(graph['node_feat']).to(torch.int64)
edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
y = torch.Tensor([homolumogap])
num_nodes = int(graph['num_nodes'])
data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
return data
def get_idx_split(self):
split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
return split_dict
if __name__ == "__main__":
dataset = MyPCQM4MDataset('dataset2')
from torch_geometric.data import DataLoader
from tqdm import tqdm
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
for batch in tqdm(dataloader):
pass
在生成一个该数据集类的对象时,程序首先会检查指定的文件夹下是否存在data.csv.gz 文件,如果不在,则会执行download 方法,这一过程是在运行super 类的__init__ 方法中发生的。然后程序继续执行__init__ 方法的剩余部分,读取data.csv.gz 文件,获取存储图信息的smiles 格式的字符串,以及回归预测的目标homolumogap 。我们将由smiles 格式的字符串转成图的过程在get() 方法中实现,这样我们在生成一个DataLoader 变量时,通过指定num_workers 可以实现并行执行生成多个图。
三、图预测任务实践
相关代码见gin_regression。
实验室的服务器跑的非常慢,后面会继续跟进。
|