torchvision包提供了一些常用的数据集和转换函数,使用torchvision甚至不需要自己写处理函数。
一、对于torchvision提供的数据集
对于这一类数据集,PyTorch已经帮我们做好了所有的事情,连数据源都不需要自己下载。 Imagenet,CIFAR10,MNIST等等PyTorch都提供了数据加载的功能,所以可以先看看你要用的数据集是不是这种情况。
比如,加载MNIST数据集:
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
),
batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'./data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
),
batch_size=TEST_BATCH_SIZE, shuffle=False)
二、对于特定结构的数据集
通过torchvision中的通用数据集ImageFolder来完成加载,它假设数据结构为如下:
root/airport/airplane(1).jpg
root/airport/airplane(2).jpg
root/airport/airplane(3).jpg
.
.
.
root/beach/beach(1).jpg
root/beach/beach(2).jpg
root/beach/beach(3)_.jpg
同样
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
train_transforms = transforms.Compose(
[transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=45),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
test_valid_transforms = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
train_directory = config.TRAIN_DATASET_DIR
valid_directory = config.VALID_DATASET_DIR
batch_size = config.BATCH_SIZE
num_classes = config.NUM_CLASSES
train_datasets = datasets.ImageFolder(train_directory, transform=train_transforms)
train_data_size = len(train_datasets)
train_data = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
三、对于普通数据集
定义数据集的类MyDataset,这个类要继承Dataset这个抽象类,然后重写下面的函数:
①__len__: 使得len(dataset)返回数据集的大小;
②__getitem__:使得支持dataset[i]能够返回第i个数据样本的下标操作
通常情况还包括初始函数__init__.
class MyDataset(torch.utils.data.Dataset):
def __init__(self, img_paths, labels, transform=None):
self.img_paths = img_paths
self.labels = labels
self.transform = transform
def __getitem__(self, index):
img_path, label = self.img_paths[index], self.labels[index]
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.img_paths)
train_set = MyDataset(
train_img_paths,
train_labels,
transform=train_transform)
train_loader = torch.utils.data.DataLoader(
train_set,
batch_size=args.batch_size,
shuffle=True,
num_workers=4,
sampler=None)
参考:链接: https://www.jianshu.com/p/6e22d21c84be.
待更新…
|