dataset
通过 from torch.utils.data import IterableDataset, DataLoader 引入. 典型用法如 loader=DataLoader(dataset:Dataset) , 二者一起等价于 tf 的 dataset.
流式数据
详见 参考[1]. 典型场景是内存盛不下, 网络数据库->dataset->model feed 流式运作. 例子见下.
from torch.utils.data import IterableDataset, DataLoader
class StreamingDataset(IterableDataset):
i = 0
def generator(self):
while True:
StreamingDataset.i += 1
yield StreamingDataset.i
def __iter__(self):
return iter(self.generator())
loader = DataLoader(StreamingDataset())
it = iter(loader)
print(next(it))
print(next(it))
"""
tensor([1])
tensor([2])
"""
参考
- a-streaming-dataloader
|