'''
Dataset:
提供读取数据和其标签的方式:
- 获取每条数据和标签
- 告诉我们总共有多少条数据
'''
from torch.utils.data import Dataset
from PIL import Image
import os
class DataSet(Dataset):
    """Image dataset backed by one class folder of the hymenoptera training set.

    Every entry of
    ``<cwd>/数据集/hymenoptera_data/hymenoptera_data/train/<type>`` is treated
    as one sample, and the folder name itself is the label of each sample.
    """

    def __init__(self, type):
        """Resolve the class folder and snapshot the names of its entries.

        Args:
            type: Name of the class sub-folder (for example ``'ants'``). It is
                also returned as the label of every sample. The name shadows
                the builtin but is kept so existing keyword callers still work.

        Raises:
            OSError: If the class folder cannot be listed.
        """
        super().__init__()
        self.root_path = os.getcwd()
        self.type = type
        # Join the components one by one instead of embedding '\\' in a single
        # literal: the result is unchanged on Windows and now also resolves to
        # real nested directories on POSIX systems.
        self.paths = os.path.join(
            self.root_path,
            '数据集', 'hymenoptera_data', 'hymenoptera_data', 'train',
            self.type,
        )
        # os.listdir() returns entries in arbitrary order, so list once and
        # sort: every index then maps to the same file for the lifetime of the
        # dataset, and the directory is not re-scanned on each access.
        self.img_names = sorted(os.listdir(self.paths))

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Integer position into the sorted list of file names.

        Returns:
            A tuple of the image opened with ``Image.open`` and the class
            folder name given at construction.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        img_path = os.path.join(self.paths, self.img_names[idx])
        img = Image.open(img_path)
        label = self.type
        return img, label

    def __len__(self):
        """Return the number of entries found in the class folder."""
        return len(self.img_names)
# Build one dataset per class folder (ants first, then bees) and combine the
# two with `+` so that `data` covers both classes.
ant_dataset, bee_dataset = (DataSet(class_dir) for class_dir in ('ants', 'bees'))
data = ant_dataset + bee_dataset
'''
Dataset:
提供读取数据和其标签的方式:
- 获取每条数据和标签
- 告诉我们总共有多少条数据
'''
from torch.utils.data import Dataset
from PIL import Image
import os
class DataSet(Dataset):
    """Image dataset whose labels are stored in per-image text files.

    Images are read from ``<cwd>/数据集/练手数据集/train/<type>``. The label of
    ``<stem>.<ext>`` is the stripped content of
    ``<cwd>/数据集/练手数据集/train/<label_dir>/<stem>.txt``.
    """

    def __init__(self, type, label_dir='bees_label'):
        """Resolve the image and label folders and snapshot the image names.

        Args:
            type: Name of the image sub-folder (for example ``'bees_image'``).
                The name shadows the builtin but is kept so existing keyword
                callers still work.
            label_dir: Name of the sibling folder holding the ``.txt`` label
                files. Defaults to ``'bees_label'``, the folder the original
                code always read from.

        Raises:
            OSError: If the image folder cannot be listed.
        """
        super().__init__()
        self.root_path = os.getcwd()
        self.type = type
        self.label_dir = label_dir
        # Join the components one by one instead of embedding '\\' in a single
        # literal: the result is unchanged on Windows and now also resolves to
        # real nested directories on POSIX systems.
        train_dir = os.path.join(self.root_path, '数据集', '练手数据集', 'train')
        self.paths = os.path.join(train_dir, self.type)
        self.label_root = os.path.join(train_dir, self.label_dir)
        # os.listdir() returns entries in arbitrary order, so list once and
        # sort: every index then maps to the same file for the lifetime of the
        # dataset, and the directory is not re-scanned on each access.
        self.img_names = sorted(os.listdir(self.paths))

    def _read_label(self, img_name):
        """Return the stripped text of the label file matching ``img_name``."""
        # splitext removes only the extension. The previous
        # ``img_name.strip('.jpg')`` removed any of the characters '.', 'j',
        # 'p', 'g' from BOTH ends, so 'pig.jpg' became 'i'.
        stem = os.path.splitext(img_name)[0]
        label_file = os.path.join(self.label_root, '{}.txt'.format(stem))
        with open(label_file, mode='rt', encoding='utf8') as f:
            return f.read().strip()

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Integer position into the sorted list of image file names.

        Returns:
            A tuple of the image opened with ``Image.open`` and the label
            string read from the matching ``.txt`` file.

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image or its label file cannot be opened.
        """
        img_name = self.img_names[idx]
        img = Image.open(os.path.join(self.paths, img_name))
        label = self._read_label(img_name)
        return img, label

    def __len__(self):
        """Return the number of entries found in the image folder."""
        return len(self.img_names)
# NOTE(review): despite its name, `ant_dataset` is built from the 'bees_image'
# folder; the name is kept because it is a module-level binding.
ant_dataset = DataSet('bees_image')
# Smoke check: print the first (image, label) pair returned by __getitem__.
print(ant_dataset[0])
|