准备数据
需要从数据集地址:https://ai.stanford.edu/~jkrause/cars/car_dataset.html下载3个压缩包,下载好之后解压
根据标签划分数据集
import os
import os.path as osp
import shutil

import scipy.io as scio
# Root of the extracted Stanford Cars dataset — replace "XXX" with the real path.
dataset_path = "XXX"
# Folders produced by unpacking the official train/test image archives.
train_img_path = osp.join(dataset_path, "cars_train")
test_img_path = osp.join(dataset_path, "cars_test")
# The devkit directory holds the .mat annotation files (labels + bounding boxes).
label_path = osp.join(dataset_path, "car_devkit", "devkit")
train_label_path = osp.join(label_path, "cars_train_annos.mat")
test_label_path = osp.join(label_path, "cars_test_annos.mat")
# Module-level counter of images copied so far by devide_dataset_by_label.
img = 0
def devide_dataset_by_label(path):
    """Copy images into per-class folders based on a Stanford Cars annotation file.

    Reads the .mat annotation file at ``path`` and copies every referenced
    image into ``<dataset_path>/classes/<label>/``.

    Parameters
    ----------
    path : str
        Path to ``cars_train_annos.mat`` or ``cars_test_annos.mat``.

    Side effects: creates directories under ``<dataset_path>/classes`` and
    increments the module-level counter ``img`` once per copied image.
    """
    global img
    annotations = scio.loadmat(path)["annotations"].squeeze()
    # Train and test images live in different folders; pick by annotation file
    # name instead of always reading from the train folder.
    img_dir = test_img_path if "test" in osp.basename(path) else train_img_path
    fields = annotations.dtype.names or ()
    # NOTE(review): the public devkit's cars_test_annos.mat withholds the
    # "class" field; fall back to label 0 in that case — confirm this matches
    # the intended handling of unlabeled test images.
    has_class = "class" in fields
    has_fname = "fname" in fields
    for number, anno in enumerate(annotations):
        # Read the real label from the annotation record rather than a zero
        # placeholder, so images are grouped by their actual class.
        label = int(anno["class"]) if has_class else 0
        # Prefer the annotation's own file name over an index-derived one.
        if has_fname:
            fname = str(anno["fname"][0])
        else:
            fname = str(number + 1).zfill(5) + ".jpg"
        dst_dir = osp.join(dataset_path, "classes", str(label))
        os.makedirs(dst_dir, exist_ok=True)
        # shutil.copy avoids shelling out via os.system and is path-safe.
        shutil.copy(osp.join(img_dir, fname), dst_dir)
        img += 1
if __name__ == '__main__':
    # Split both the train and test images into per-class folders, then
    # report the total number of images processed (official total: 16185).
    devide_dataset_by_label(train_label_path)
    devide_dataset_by_label(test_label_path)
    print(f"devide {img} imgs")
一共给16185张图片分了类,和数据集官方数据一样
加载成dataloader
def get_Stanford_Cars_dataloader(mode="train", way=5, shot=2, query=10):
    """Build an episodic few-shot DataLoader over the Stanford Cars classes.

    Parameters
    ----------
    mode : str
        One of "train", "val" or "test" — selects a 60/20/20 class split.
    way : int
        Number of classes sampled per episode.
    shot : int
        Support images per class.
    query : int
        Query images per class.

    Returns
    -------
    torch.utils.data.DataLoader
        Yields episodes sampled by EpisodicBatchSampler.

    Raises
    ------
    ValueError
        If ``mode`` is not one of "train", "val" or "test" (previously this
        silently produced an empty loader).
    """
    # Lazily build the per-class folder layout on first use.
    if not osp.exists(osp.abspath(osp.join(dataset_path, "classes"))):
        devide_dataset_by_label(train_label_path)
        devide_dataset_by_label(test_label_path)
        print(f"devide {img} imgs")
    classes_path = osp.abspath(osp.join(dataset_path, "classes"))
    # Skip macOS .DS_Store artifacts; keep only real class directories.
    class_dirs = [
        osp.join(classes_path, name)
        for name in os.listdir(classes_path)
        if "DS_Store" not in name
    ]
    # A class is usable only if it holds enough images for one full episode.
    usable = [d for d in class_dirs if len(os.listdir(d)) >= query + shot]
    # 60/20/20 split of the usable classes into train/val/test.
    n = len(usable)
    splits = {
        "train": usable[:int(n * 0.6)],
        "val": usable[int(n * 0.6):int(n * 0.8)],
        "test": usable[int(n * 0.8):],
    }
    if mode not in splits:
        raise ValueError(f"unknown mode {mode!r}; expected 'train', 'val' or 'test'")
    lists = splits[mode]
    # Pipeline: wrap the class path in a dict, load its images (64x64 assumed
    # — TODO confirm against load_class_images), then cut a shot/query episode.
    transform_fns = compose([partial(convert_dict, "class"),
                             partial(load_class_images, 64),
                             partial(extract_episode, shot, query)])
    episode = int(len(lists) / way) + 1
    ds = TransformDataset(ListDataset(lists), transform_fns)
    sampler = EpisodicBatchSampler(len(ds), way, episode)
    return torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)
全部代码见上文
分享至码云