pytorch在torch.utils.data 对常用数据加载进行了封装,可用很容易实现数据读取和批量加载。
DataSet
DataSet 是 pytorch 提供的用于包装数据的抽象类。可以通过继承并实现其中的抽象方法来自定义DataSet。
自定义DataSet要继承并实现两个成员方法:
__getitem__(self, idx) :该方法需要实现通过索引获得一条数据。__len__(self) :该方法需要返回数据集的长度。
DataLoader
DataLoader提供了对DataSet的读取操作,常用参数有:
batch_size :每个批次的大小。shuffle :是否对数据进行洗牌操作。num_work :加载数据时使用几个子进程。
示例
在样例中,生成了
[
0
,
400
)
[0, 400)
[0,400)的整数的序列,并将其转换为
[
100
,
2
,
2
]
[100, 2, 2]
[100,2,2] 的矩阵。
示例代码:
from torch.utils.data import Dataset
import torch
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
if __name__ == '__main__':
data = torch.range(0, 399)
data = data.view(100, 2, 2)
ds = MyDataset(data)
print('数据集长度:', len(ds))
print('ds[0]: ', ds[0])
dl = torch.utils.data.DataLoader(
ds, batch_size=2, shuffle=True, num_workers=0
)
itDl = iter(dl)
print('next(itDl): ', next(itDl))
输出结果:
数据集长度: 100
ds[0]: tensor([[0., 1.],
[2., 3.]])
next(itDl): tensor([[[276., 277.],
[278., 279.]],
[[284., 285.],
[286., 287.]]])
从结果可以看出,DataLoader对数据进行了洗牌,并以每批次2个数据输出。
注意:
这里的MyDataSet 仅做了简单实现,并不一定只能传入数据,也可以传入文件路径等,然后对数据进行读取并保存。
如:
class BulldozerDataset(Dataset):
""" 数据集演示 """
def __init__(self, csv_file):
"""实现初始化方法,在初始化的时候将数据读载入"""
self.df=pd.read_csv(csv_file)
def __len__(self):
'''
返回df的长度
'''
return len(self.df)
def __getitem__(self, idx):
'''
根据 idx 返回一行数据
'''
return self.df.iloc[idx].SalePrice
|