数据集存放
将数据集分类存放,不同类别放在不同目录下,目录名即标签,数据集存放格式如下: root/ants/001.png root/ants/002.png root/ants/003.png … root/bees/001.png root/bees/002.png root/bees/003.png …
图像数据预处理
1、使用transforms设置图像预处理操作
设置剪裁、缩放、翻转等参数,详细参考链接transforms
data_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
2、使用 datasets.ImageFolder对指定位置图像进行如上预处理操作
root处写入本地数据集地址,对该地址的图像进行上文data_transform设置的图像预处理操作
train_dataset = datasets.ImageFolder(root='Dataset/train', transform=data_transform)
图像数据集加载
1、使用torch.utils.data.DataLoader加载数据集
加载预处理好的数据集train_dataset,每批数据量大小为64
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64)
源码如下:
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder(root='Dataset/train', transform=data_transform)
train_data_size = len(train_dataset)
print("训练数据集长度为:{}".format(train_data_size))
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64)
for img, label in train_dataloader :
print("图像img的形状{},标签label的值{}".format(img.shape, label))
print("图像数据预处理后:\n",img)
参考文章: transforms
|