torch大神请忽略此文。。。
1,一个简单例子回顾DataSet
from torch.utils.data import Dataset
class dataset(Dataset):
def __init__(self):
self.data = np.array([[1,1,1,1],
[2,2,2,2],
[3,3,3,3],
[4,4,4,4],
[5,5,5,5],
[6,6,6,6]])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataset = dataset()
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=2,
shuffle=False,
num_workers=1)
for i, data_ in enumerate(dataloader):
print(i)
print(data_)
观察运行结果:
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
这个例子足够理解DataSet了
2,维度
参考这篇文章: https://blog.csdn.net/xddwz/article/details/108405817
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
def __init__(self, transform=None):
self.transform = transforms.Compose([
transforms.ToTensor()
])
self.image_path = './image_data2/'
self.image_names = os.listdir(self.image_path)
def __len__(self):
return len(self.image_names)
def __getitem__(self, item):
image_name = self.image_names[item]
image = cv2.imread(os.path.join(self.image_path, image_name))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.from_numpy(image).permute(2, 0, 1)
return image
这段代码是上述链接中的,搬运过来的原因是强调一个事情:我们知道torch的维度顺序是BCHW,而上述代码中的__getitem__() ,是返回一张图片,那么这个时候我们需要注意的是,单张图片本来的维度顺序是HWC,即维度是(height, width, channel),我们需要将它的维度调整为(channel, height, width),然后再返回。
同时,根据上述代码,也需要注意,返回的image的shape不是(1, channel, height, width),而是(channel, height, width),batch对应的那一维在__getitem__() 中不需要考虑。
对于一些channel为1的图片,如果需要增加channel维度,那么只需要squeeze(0)就行了。
|