ImageFolder是torchvision库下的一个常用的加载图片数据集的方式。
图片数据集需要以这样的方式进行组织:将不同类别的数据分别放在不同目录下,每个目录的名字就是数据的标签。存放方式如下图所示
├──cat
| ├──cat_001.jpg
| ├──cat_002.jpg
| └──……
├──dog
| ├──dog_001.jpg
| ├──dog_002.jpg
| ├──dog_003.jpg
| └──……
└──……
随便找了几张图片作为这个示例。
ImgaFolder函数原型如下
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None
loader=default_loader, is_valid_file=)
参数的含义如下
- root:string类型。存放数据集的根目录。
- transform:function类型,输入为loader读取的PIL图片,可选。对输入的图片进行transform变换,返回转换后的图片。
- target_transform:function类型,输入为label,可选。对label进行transform变换。
- loader:function类型,可选。加载图片的方式,默认读取RGB格式的PIL Image类型的图像。
- is_valid_file: function类型,可选。用来检查一个Image文件是否是空的(检查该图片是否损坏)
返回Dataset类型。
label就是一张图像的标签,也就是将子文件夹的名字以字典的形式存储起来,即{类名 : 序号},序号从0开始顺序向后计数。上文中的例子中的label就是
{cat : 0}
{dog : 1}
……
常用的成员变量为
- classes:list类型,类名
- class_to_idx:list类型,label。
- imgs:list类型,由图片名称和它的类型组成。
from torchvision.datasets import ImageFolder
dataset = ImageFolder("data/")
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)
输出为 dataset[i]可以得到对应图像的数据。dataset[i][0]是PIL图像类型,dataset[i][1]是该图像的类别。
print(dataset[0])
输出为
|