一、加载pytorch自带数据集
torchvision.datasets是torch.utils.data.Dataset的实现。 包括如下数据集: `__all__` = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST', 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', 'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
1.使用torchvision.datasets加载数据集
import torch
import torchvision
from torchvision import transforms
from PIL import Image
# Download (if not already cached) and load the CIFAR-10 training split.
# Without a `transform`, samples are returned as PIL images.
cifarSet = torchvision.datasets.CIFAR10(
    root="../data/cifar/",
    train=True,
    download=True,
)
2.使用torch.utils.data.DataLoader来实例化
# Batch the dataset: 10 samples per batch, original order, 2 worker processes.
cifarLoader = torch.utils.data.DataLoader(
    cifarSet,
    batch_size=10,
    shuffle=False,
    num_workers=2,
)
3.测试
# Preview the first batch, then stop.
for i, data in enumerate(cifarLoader, 0):
    # Each batch is an (images, labels) pair. The original code indexed the
    # pair with the loop counter (`data[i][0]`), which only works by accident
    # for i == 0 and raises IndexError from the second batch onward.
    images, labels = data
    print(images[0])
    # Convert the first sample of the batch back to a PIL image for display.
    # NOTE(review): this assumes the dataset was built with a ToTensor
    # transform so `images` is a tensor batch — confirm against cifarSet.
    img = transforms.ToPILImage()(images[0])
    img.show()
    break
二、加载个人的数据集
1.继承Dataset类,生成数据集
import torch.utils.data as data
#定义myDataSet类来继承Dataset
#generate train_data or test_data...
def default_loader(path):
    """Load the image at *path* and return it converted to RGB.

    Uses the context-manager form recommended by Pillow so the underlying
    file handle is closed promptly; the original `Image.open(path)` call
    leaked the handle until garbage collection. `convert()` forces the
    pixel data to be read before the file is closed.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
class myDataSet(data.Dataset):
    """Dataset backed by a text file listing image paths.

    Args:
        label_txt: path to a text file with one image per line (absolute
            paths, per the surrounding notes); only the first
            whitespace-separated token of each line is used.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable for targets. Kept for API
            compatibility; this dataset produces no targets, so it is
            currently unused.
        loader: callable mapping a path to an image. Defaults to
            default_loader (opens the file as an RGB PIL image).
    """

    def __init__(self, label_txt, transform=None, target_transform=None,
                 loader=None):
        super(myDataSet, self).__init__()
        self.transform = transform
        self.target_transform = target_transform
        # Resolve the default lazily so the class definition does not
        # require default_loader to exist at definition time; the default
        # behavior is unchanged.
        self.loader = loader if loader is not None else default_loader
        # `with` guarantees the file is closed (the original leaked the
        # handle). Blank lines are skipped (the original crashed on them
        # with an IndexError); split() already discards the trailing '\n'.
        with open(label_txt, 'r') as fn:
            self.imgs = [line.split()[0] for line in fn if line.strip()]

    def __len__(self):
        """Number of images listed in label_txt."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load, optionally transform, and return the image at *index*."""
        # Fix: the original read `self.img` (AttributeError) and joined
        # against an undefined `self.root` with an un-imported `os`;
        # paths in label_txt are absolute, so use them directly.
        path = self.imgs[index]
        img = self.loader(path)
        # Fix: `transform` was stored but never applied.
        if self.transform is not None:
            img = self.transform(img)
        return img
label_txt的格式如下: 每一行是一个图像的绝对路径 同时,需要重写__len__与__getitem__两个函数如上
2.加载数据集
def get_my_data():
    """Build and return a DataLoader over the custom training dataset.

    NOTE(review): `label_txt` is left empty in the original — fill in the
    real list files before use. Assumes BATCH_SIZE is defined at module
    level — confirm.
    """
    # Fix: the keyword must match myDataSet's parameter name ('transform',
    # not 'transforms'), and ToTensor lives in torchvision.transforms
    # (the original called the nonexistent `transform.ToTensor()`).
    train_data = myDataSet(label_txt='', transform=transforms.ToTensor())
    test_data = myDataSet(label_txt='', transform=transforms.ToTensor())
    # Fix: bare `DataLoader` is undefined in this file; use the
    # fully-qualified name available via the top-level `import torch`.
    train_loader = torch.utils.data.DataLoader(
        train_data, shuffle=True, batch_size=BATCH_SIZE, num_workers=1)
    # test_loader = torch.utils.data.DataLoader(
    #     test_data, shuffle=False, batch_size=BATCH_SIZE, num_workers=1)
    return train_loader
参考文献: https://blog.csdn.net/sinat_42239797/article/details/90641659 https://zhuanlan.zhihu.com/p/27434001
|