一、torchvision工具包
- VisionDataset类:存储数据集所在的根目录、训练数据和预测目标的变换函数(transform和target_transform),但是没有实现__getitem__和__len__方法。
- DatasetFloder类:继承了VisionDataset类,实现__getitem__和__len__方法。
- 使用时,数据集存储在一个目录下,而且含有多个子目录。每个子目录是一类图片。
- 构造函数一开始调用内部的_find_classes找到具体预测目标的类别,和类别对应的整数(class_to_idx实现)。得到一个记录数据路径和数据预测目标的列表。另外还会传入参数loader,用来载入数据。
- _getitem__传入index,根据index从self.samples一条数据记录的parh和target。使用loader载入数据,self.transform和self.target_transform进行数据变换。最后返回变换之后的数据和预测目标。
代码如下:
""" 该代码仅为演示类的构造方法所用,并不能实际运行
class VisionDataset(data.Dataset):
def __init__(self, root, transforms=None, transform=None,
target_transform=None):
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class DatasetFolder(VisionDataset):
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
self.loader = loader
def __getitem__(self, index):
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
- torchvision.transforms包:包含内置的转换类,含有两种:
- 作用于PIL.Image类上:通过PIL.Image.open打开一张图片,返回PIL.Image类。
- torchvision.transforms.CenterCrop类:输入整数或元组(长宽),从图片中截取一个正方形区域,返回新的PIL.Image类
- torchvision.transforms.Resize:缩放图像。输入整数或元组,代表缩放目标图像的大小。
- torchvision.transforms.ToTensor:将图像从 0-255之间的整数值,转换为0-1之间的浮点数张量。
- 作用于PyTorch张量上
- torchvision.transforms.ToPILImage类:将浮点数张量转换为PIL.Image图像。
- torchvision.transforms.Noemalize:对图片转换后的张量进行标准化。两个参数是所有图片的均值mean张量和标准差张量std。输出是(x-mean)/std
- torchvision.transforms.Compose:构造一个包含列表的转换类,输出列表里的转换依次作用后的结果。
|