参考:https://zhuanlan.zhihu.com/p/52807406?还是比较喜欢这个作者把源码放出来,就是继承 Dataset需要覆盖什么,这样就可以了,如下图:
所以,知道需要overide什么,我们就可以定义自己的 CustomData了,如下所示:
import random
import torchvision.transforms as transforms
import cv2
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from skimage import io, transform
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose(
[
transforms.Tensor(),
normalize
]
)
class CustomData(Dataset):
def __init__(self,file_path, transform=None):
super(CustomData, self).__init__()
self.file_path = file_path
self.transform = transform
self.image_names = os.listdir(self.file_path)
def __getitem__(self, item):
image_name = self.image_names[item]
image = cv2.imread(os.path.join(self.file_path, image_name))
# if self.transform:
# image = self.transform(image)
return image
def __len__(self):
return len(self.image_names)
image_loader = CustomData(file_path=r'D:\data\1', transform=preprocess)
img = image_loader.__getitem__(0)
cv2.imshow('img', img)
cv2.waitKey(0)
其实:
1.先初始化好,把图片路径和target弄好,以及对应上
2.然后定义“__getitem__”,这个方法就是给一个item,然后获取到对应的图片和target(标签)
3.定义“__len__”,获取到整个图片的数量
|