torchvision.datasets中包含如下数据集:
- MNIST
- COCO
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR-10,CIFAR-100
- STL10
MNIST
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
参数说明:
- - root :?
processed/training.pt ?和?processed/test.pt ?的主目录 - - train:True=训练集,False=测试集
- - download:True=从互联网上下载数据集,并把数据集放在root目录下,如果数据集之前下载过,将处理过的数据(mnist.py中有相关函数)放在processed文件夹下
ImageFolder
一个通用的数据加载器,数据集中的数据以以下方式组织
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
dset.ImageFolder(root="root folder path", [transform, target_transform])
有以下成员变量:
- self.classes - 用一个list保存类名
- self.class_to_idx - 类名对应的索引
- self.imgs - 保存(img-path, class) tuple的list
root是根文件夹目录
torch.utils.data
At the heart of PyTorch data loading utility is the?torch.utils.data.DataLoader?class. It represents a Python iterable over a dataset, with support for
pytorch数据加载的核心是torch.utils.data.DataLoader类,它表示在一个数据集上的Python迭代。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
参数说明:
- dataset:加载数据的数据集
- batch_size:每次加载多少个样本(default=1)
- shuffle:True=每个epoch将数据重新打乱(default=False)
- sampler:定义从数据集中选取数据的采样方法(default=None)
- num_workers:用于数据加载的子进程数。=0表示数据将在主进程中加载。
|