IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> paddlepaddle实现十二生肖的分类之数据的预处理(一) -> 正文阅读

[人工智能]paddlepaddle实现十二生肖的分类之数据的预处理(一)

数据集说明

数据集一共包含3个目录trainvalidtest,每个目录都包含了12生肖(类别)的图片,通过下面的链接可以直接下载数据集

数据下载地址:下载地址
项目地址:项目链接

数据分析

统计数据集中每个类别的数据分布情况

import os

def print_classes_info(mode="train",data_dir = "data/signs"):
    datasets_dir = os.path.join(data_dir,mode)
    classes_names = os.listdir(datasets_dir)
    #用来保存每个类别的数量信息
    classes_num_infos = dict()
    for class_name in classes_names:
        #获取类别的目录
        class_dir_path = os.path.join(datasets_dir,class_name)
        img_names = os.listdir(class_dir_path)
        #记录每个类别的图片数量
        classes_num_infos[class_name] = len(img_names)
    print("{}:{}".format(mode,classes_num_infos))

#打印数据的分布情况
print_classes_info("train")
print_classes_info("valid")
print_classes_info("test")

train:{‘goat’: 600, ‘tiger’: 600, ‘horse’: 600, ‘snake’: 600, ‘pig’: 600, ‘dragon’: 600, ‘ox’: 600, ‘monkey’: 600, ‘rabbit’: 600, ‘rooster’: 600, ‘dog’: 600, ‘ratt’: 600}
valid:{‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55}
test:{‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55}

在训练集中每个类别包含600张图片,验证集中每个类别包含55张图片,测试集中每个类别包含55张图片,因为这里的数据都比较平衡,后面我们就不需要去考虑数据的平衡问题了。

展示图片

从数据集中选择一部分数据进行查看,了解数据的分布特征

import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

def show_images(mode="train",data_dir = "data/signs",row_num=3,col_num=4):
    datasets_dir = os.path.join(data_dir,mode)
    #获取所有的类别数据
    classes_names = os.listdir(datasets_dir)
    #用来保存图片数据
    images_list = []
    title_list = []
    for cls_name in classes_names:
        cls_dir_path = os.path.join(datasets_dir,cls_name)
        img_name_list = os.listdir(cls_dir_path)
        for img_name in img_name_list:
            img_path = os.path.join(cls_dir_path,img_name)
            image = Image.open(img_path)
            #保存图片和标签数据
            images_list.append(np.array(image))
            title_list.append(cls_name)
            break
            
    #设置图片的大小
    plt.figure(figsize=(8,8))
    for index in range(row_num*col_num):
        plt.subplot(row_num,col_num,index+1)
        plt.imshow(images_list[index])
        plt.title(title_list[index])
        #隐藏x轴和y轴的标签刻度
        plt.xticks([])
        plt.yticks([])
    
    plt.show()


show_images()

在这里插入图片描述

数据加载器

基于paddlepaddle提供的paddle.io.Dataset类,封装一个十二生肖的数据加载器,用于后面的模型训练和评估,将图片的预处理也封装在里面

import os
import paddle
from paddle.vision import transforms
from PIL import Image
import numpy as np

class ZodiacDatasets(paddle.io.Dataset):
    """
    加载十二生肖数据
    """
    def __init__(self,mode="train",data_root="data/signs",img_size=(224,224)):
        self.data_root = data_root
        #判断mode是否正确
        if mode not in ["train","valid","test"]:
            assert("{} is illegal,mode need is one of train,valid,test")
        #获取数据集的目录
        self._data_dir_path = os.path.join(data_root,mode)
        #获取十二生肖的类别名称
        self._zodiac_names = sorted(os.listdir(self._data_dir_path))
        #用来保存图片的路径
        self._img_path_list = []
        for name in self._zodiac_names:
            img_dir_path = os.path.join(self._data_dir_path,name)
            img_name_list = os.listdir(img_dir_path)
            for img_name in img_name_list:
                img_path = os.path.join(img_dir_path,img_name)
                self._img_path_list.append(img_path)
        #定义图像的预处理函数
        if mode == "train":
            self._transform = transforms.Compose([
                transforms.RandomResizedCrop(img_size),   #缩放图片并随机裁剪图片为指定shape
                transforms.RandomHorizontalFlip(0.5),     #随机水平翻转图片的概率为0.5
                transforms.ToTensor(),                    #转换图片的格式由HWC ==> CHW
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])  #图片通道像素的标准化
            ])
        else:
            self._transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
            ])
    def __getitem__(self,index):
        """根据index获取图片数据
        """
        #获取图片的路径
        img_path = self._img_path_list[index]
        #获取图片的标签
        img_label = img_path.split("/")[-2]
        #将生肖的标签名称转换为数字标签
        label_index = self._zodiac_names.index(img_label)
        #读取图片
        img = Image.open(img_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        #图片的预处理
        img = self._transform(img)
        return img,np.array(label_index,dtype=np.int64)

    def __len__(self):
        """获取数据集的大小
        """
        return len(self._img_path_list)


#加载训练集
train_datasets = ZodiacDatasets(mode="train")
#统计训练集的大小
print(len(train_datasets))
for img,img_label in train_datasets:
    print(img.shape,img_label)
    break
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-03 16:13:20  更:2022-03-03 16:17:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 1:41:08-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码