数据集说明
数据集一共包含 3 个目录：train、valid 和 test，每个目录都包含十二生肖（共 12 个类别）的图片，可以通过下面的链接直接下载数据集。
数据下载地址：下载地址；项目地址：项目链接
数据分析
统计数据集中每个类别的数据分布情况
import os
def print_classes_info(mode="train", data_dir="data/signs"):
    """Print and return the number of images per class for one dataset split.

    The split directory ``<data_dir>/<mode>`` is expected to contain one
    sub-directory per class; every entry inside a class directory is counted
    as one image. Entries of the split directory that are not directories
    (e.g. a stray ``.DS_Store``) are ignored instead of crashing ``os.listdir``.

    Args:
        mode: Name of the split sub-directory, e.g. "train", "valid" or "test".
        data_dir: Root directory that contains the split directories.

    Returns:
        dict mapping class name -> number of entries in that class directory,
        in ``os.listdir`` order. The same dict is printed as ``"<mode>:<dict>"``.
    """
    datasets_dir = os.path.join(data_dir, mode)
    classes_num_infos = {}
    for class_name in os.listdir(datasets_dir):
        class_dir_path = os.path.join(datasets_dir, class_name)
        # Only sub-directories are classes; skip plain files at split level.
        if not os.path.isdir(class_dir_path):
            continue
        classes_num_infos[class_name] = len(os.listdir(class_dir_path))
    print("{}:{}".format(mode, classes_num_infos))
    return classes_num_infos
# Report the per-class image counts of every split, in train/valid/test order.
for split_name in ("train", "valid", "test"):
    print_classes_info(split_name)
train:{'goat': 600, 'tiger': 600, 'horse': 600, 'snake': 600, 'pig': 600, 'dragon': 600, 'ox': 600, 'monkey': 600, 'rabbit': 600, 'rooster': 600, 'dog': 600, 'ratt': 600}
valid:{'goat': 55, 'tiger': 55, 'horse': 55, 'snake': 55, 'pig': 55, 'dragon': 55, 'ox': 55, 'monkey': 55, 'rabbit': 55, 'rooster': 55, 'dog': 55, 'ratt': 55}
test:{'goat': 55, 'tiger': 55, 'horse': 55, 'snake': 55, 'pig': 55, 'dragon': 55, 'ox': 55, 'monkey': 55, 'rabbit': 55, 'rooster': 55, 'dog': 55, 'ratt': 55}
训练集中每个类别包含 600 张图片，验证集和测试集中每个类别各包含 55 张图片。由于各类别的样本数量完全相同，数据是平衡的，后面就不需要再考虑类别不平衡的问题了。
展示图片
从数据集的每个类别中各选取一张图片进行展示，直观地了解数据的特征。
import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
def show_images(mode="train", data_dir="data/signs", row_num=3, col_num=4):
    """Display the first image of each class of one split in a grid.

    Class directories are visited in sorted order, and from each one the
    first file (in sorted order) is loaded and titled with its class name.

    Args:
        mode: Split sub-directory name, e.g. "train", "valid" or "test".
        data_dir: Root directory containing the split directories.
        row_num: Number of grid rows.
        col_num: Number of grid columns.

    At most ``row_num * col_num`` images are drawn. If there are fewer
    classes than grid cells, the remaining cells stay empty instead of
    raising ``IndexError``.
    """
    datasets_dir = os.path.join(data_dir, mode)
    images_list = []
    title_list = []
    for cls_name in sorted(os.listdir(datasets_dir)):
        cls_dir_path = os.path.join(datasets_dir, cls_name)
        # Skip stray files (e.g. ".DS_Store") that are not class directories.
        if not os.path.isdir(cls_dir_path):
            continue
        img_name_list = sorted(os.listdir(cls_dir_path))
        if not img_name_list:
            continue  # empty class directory: nothing to show
        img_path = os.path.join(cls_dir_path, img_name_list[0])
        # np.array() copies the pixels, so the file can be closed right away
        # instead of leaking one open handle per class.
        with Image.open(img_path) as image:
            images_list.append(np.array(image))
        title_list.append(cls_name)

    max_cells = row_num * col_num
    plt.figure(figsize=(8, 8))
    for index, (image, title) in enumerate(zip(images_list[:max_cells], title_list)):
        plt.subplot(row_num, col_num, index + 1)
        plt.imshow(image)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.show()
# Preview one training image per class on a 3 x 4 grid (the default layout).
show_images(mode="train", data_dir="data/signs", row_num=3, col_num=4)
数据加载器
基于 PaddlePaddle 提供的 paddle.io.Dataset 类，封装一个十二生肖数据集类，用于后续的模型训练和评估，并将图片的预处理操作也封装在其中。
import os
import paddle
from paddle.vision import transforms
from PIL import Image
import numpy as np
class ZodiacDatasets(paddle.io.Dataset):
    """Twelve-zodiac image dataset read from ``<data_root>/<mode>/<class>/<file>``.

    Class indices are assigned by sorting the class directory names of the
    chosen split, so they agree across splits that contain the same classes.

    Args:
        mode: Split to load; one of "train", "valid", "test".
        data_root: Root directory containing the split directories.
        img_size: Output crop size as (height, width); an int means a square.

    Each item is ``(image, label)``: ``image`` is the output of the
    transform pipeline (ToTensor + Normalize), and ``label`` is a 0-d
    ``np.int64`` array holding the class index.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    _MODES = ("train", "valid", "test")
    # Channel statistics passed to Normalize after ToTensor.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, mode="train", data_root="data/signs", img_size=(224, 224)):
        super().__init__()
        # Explicit check: `assert("<non-empty string>")` can never fail, and
        # asserts are stripped under `python -O`.
        if mode not in self._MODES:
            raise ValueError(
                "{} is illegal, mode must be one of train, valid, test".format(mode)
            )
        self.data_root = data_root
        self._data_dir_path = os.path.join(data_root, mode)
        # Only sub-directories are classes; stray files (e.g. ".DS_Store") are ignored.
        self._zodiac_names = sorted(
            name
            for name in os.listdir(self._data_dir_path)
            if os.path.isdir(os.path.join(self._data_dir_path, name))
        )
        # Parallel lists: path of each sample and its class index. Storing the
        # label here avoids re-parsing the path (which was not portable to
        # Windows separators) and a linear list search on every __getitem__.
        self._img_path_list = []
        self._label_list = []
        for label_index, name in enumerate(self._zodiac_names):
            img_dir_path = os.path.join(self._data_dir_path, name)
            for img_name in sorted(os.listdir(img_dir_path)):
                self._img_path_list.append(os.path.join(img_dir_path, img_name))
                self._label_list.append(label_index)

        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if mode == "train":
            self._transform = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._MEAN, std=self._STD),
            ])
        else:
            # Evaluation must be deterministic: resize the shorter side, then
            # take the *center* crop (a random crop would give different inputs
            # on every pass over valid/test). Resizing to at least the crop size
            # keeps 256 for the default 224 crop and avoids cropping past the
            # image border for larger img_size.
            resize_size = max(256, max(img_size))
            self._transform = transforms.Compose([
                transforms.Resize(resize_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._MEAN, std=self._STD),
            ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        img_path = self._img_path_list[index]
        label_index = self._label_list[index]
        # convert("RGB") returns a loaded copy (and maps grayscale/RGBA/palette
        # images to 3 channels), so the file handle can be closed immediately.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        img = self._transform(img)
        return img, np.array(label_index, dtype=np.int64)

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self._img_path_list)
# Build the training split, then inspect its size and its first sample.
train_datasets = ZodiacDatasets(mode="train")
print(len(train_datasets))
first_sample = next(iter(train_datasets), None)
if first_sample is not None:
    img, img_label = first_sample
    print(img.shape, img_label)
|