PyTorch深度学习数据相关
学习笔记
数据集创建 data.Dataset
pytorch的数据集创建使用的是data.Dataset函数以及可能使用到的库:
from torch.utils import data
from torchvision import transforms as T
继承data.Dataset类创建数据集主要包含两个关键def 用以分别保证每次循环中DataLoader读取的内容以及循环的次数:
class mydataset(data.Dataset):
def __init__(self, path, transforms=None):
self.imgs = get_file_paths(path)
self.n_samples = len(self.imgs)
if transforms is not None:
normalize = T.Normalize(mean=[], std=[])
self.transforms = T.Compose([
T.Scale(224),
T.Centercrop(224),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
img = Image.open(imgs[index]).convert('RGB')
label = imgs[index].split('')[]
if self.transforms is not None:
img = self.transformer(img)
return img, label
def __len__(self):
return self.n_samples
获取文件路径的函数:
def get_file_paths(path):
file_paths = []
for file_name in os.listdir(path):
file_path = os.path.join(path, file_name)
file_parhs.append(file_path)
file_paths = sorted(file_paths)
return file_paths
这样就创建了一个可供DataLoader调用的Dataset类
数据载入 data.DataLoader
首先,需要实例化自己的数据集:
train_dataset = mydataset(opt.data_path, transforms=True)
接着载入该数据集:
train_loader = data.DataLoader(
train_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.num_workers
)
之后就可以进行循环读取img和label了:
for ii, (data, label) in tqdm(train_loader):
train()
数据预处理 torchvision.transforms
transforms在定义之后可以直接在dataset中使用,这里仅列出常用的transforms:
torchvision.transforms.Compose(transforms)
transforms.Resize()
transforms.CenterCrop()
transforms.ToTensor()
transforms.Normalize(mean=[], std=[])
|