IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (!详解 Pytorch实战:①)kaggle猫狗数据集二分类:加载(集成/自定义)数据集 -> 正文阅读

[人工智能](!详解 Pytorch实战:①)kaggle猫狗数据集二分类:加载(集成/自定义)数据集

这系列的文章是我对Pytorch入门之后的一个总结,特别是对数据集生成加载这一块加强学习

另外,这里有一些比较常用的数据集,大家可以进行下载:

需要注意的是,本篇文章使用的PyTorch的版本是v0.10.0

《深度学习常用的数据集,包括各种数据跟图像数据》

《kaggle猫狗大战[包含训练(25000张猫狗照片)和测试数据集(12500张猫狗照片)》


目录

一、加载已被集成在Pytorch中的数据集?

1、torchvision和torchvision.datasets:数据集

2、torchvision.transforms和torchvision.transforms.Compose:图像预处理

torchvision.transforms.ToTensor()

torchvision.transforms.Normalize(mean,?std,?inplace=False)?

3、torch.utils.data.DataLoader和torch.utils.data.Dataset

torch.utils.data.DataLoader

torch.utils.data.Dataset

二、加载自定义数据集

1、torchvision.datasets.DatasetFolder

2、torchvision.datasets.ImageFolder

三、加载本地kaggle猫狗数据集


一、加载已被集成在Pytorch中的数据集?

在这第一大块的内容中,可以了解到几个函数:

torchvision和torchvision.datasets

torchvision.transforms和torchvision.transforms.Compose

torch.utils.data.DataLoader / DataLoaderItertorch.utils.data.Dataset

1、torchvision和torchvision.datasets:数据集

《PyTorch如何加载数据集(自定义数据集)》这篇博客通过代码展示了Pytorch加载数据集的两种方法,对于已被集成在Pytoch内的数据集:比如CIFAR-10,CIFAR-100,MNIST等等,此类数据集可以直接使用Pytorch的内置函数(torchvision.datasets.XXX来直接加载)

Pytorch中官方文档对torchvision的解释:torchvision — Torchvision 0.10.0 documentation (pytorch.org)

这个Pytorch中的一个库,torchvision包包含了一些比较流行的数据集(datasets)、模型架构(model architectures)和用于计算机视觉常见的图像转换(image transformations for computer vision)

另外这里还有一些比较常见的库,例如:

  • torchtext:包含数据处理的实用工具和自然语言处理方面的的流行数据集
  • torchaudio:包含一些音频I/O,音频方面的转化和流行的数据集
  • ……

Pytorch中官方文档对torchvision.datasets的解释:torchvision.datasets — Torchvision 0.10.0 documentation (pytorch.org)

所有的数据集都是torch.utils.data.Dataset的子类,就是说这些数据集已经实现了__getitem__()“返回索引”和__len__“返回长度”的方法。因此它们都可以传递给torch.utils.data.Dataset,它可以使用torch.multiprocessing并行加载多个样本

比如:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')  # 图片集的路径
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

所有的数据集都使用相同的API,同时都用两个常见的参数transform和target_transforms去将输入(X)和标签(Y)分别进行张量转换

不过你也可以使用base classes创造自己的数据集

回到这里来……

对于加载已被集成在Pytoch内的数据集,我们以CIFAR为例(通过下面程序进行加载? ? ? ? ??torchvision.datasets.CIFAR10):

?对于其中参数如下:

那么,torchvision.datasets.CIFAR10?加载数据集之后,返回的是什么呢??

?

?这里的关于这条函数的源文件:SOURCE CODE FOR TORCHVISION.DATASETS.CIFAR

因此,加载内置数据集的例程如下:

transform_train = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

cifar10_training = torchvision.datasets.CIFAR100(root='./data', 
                                                  train=True,
												  download=True, 
                                                  transform=transform_train)

cifar10_training_loader = DataLoader(cifar100_training, 
                                      shuffle=shuffle, 
                                      num_workers=num_workers, 
                                      batch_size=batch_size)

?接下来看看transforms的具体用法

2、torchvision.transforms和torchvision.transforms.Compose:图像预处理

Pytorch中官方文档对torchvision.transforms的解释:torchvision.transforms — Torchvision 0.10.0 documentation (pytorch.org)

?transforms其实也是torchvision当中作为图像转换的一个包,图像可以通过函数compose连接在一起。大多的图像转换都是可以同时接收PIL图像和张量图像的,也可以单独接收PIL图像或者张量图像的。

torchvision.transforms:常用的数据预处理方法,提升泛化能力,包括:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度、饱和度及对比度变换等。数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。

不仅如此,除了接收单独的图像之外也可以接收处理成批量的图片,也就是我们在神经网络中经常遇到的batch。一张张量图像的大小表示为(C,H,W)--->(channel, height, wight),但是a batch of 张量图像则表示为(B,C,H,W)--->(batch, channel, height, wight)

同时,图像的数据应当进行归一化

如果需要转换的参数有很多,那就需要通过torchvision.transforms.Compose(transforms)将所有的转换类型组合到一起。但是此类转换,并不支持torchscript。

(torchscript语言自身是python语言的一个子类,目的是将Pytorch模型转化成torchscript好方便在C++的环境中进行调用模型,如需要进行转换,这要使用torch.nn.Sequential函数,这里就不详细说明了,可点击官方文档进行参考)

我们看看torchvision.transforms.Compose(transforms)的使用方法:

?那常见的transforms有哪些呢?哪些在我们的项目中需要用到的呢?

  • torchvision.transforms.ToTensor()

? ? ? ??将PIL或者numpy.ndarray转化为张量,但是有一个需要注意的是!!!对于这一条函数而言,如果,PIL图像的模式属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)中的,或者是将PIL Image 或者numpy.ndarray的数据类型是dtype = np.uint8的,那就:

numpy.ndarray (H x W x C) in the range [0, 255]

转化为? ??torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

  • torchvision.transforms.Normalize(mean,?std,?inplace=False)?

????????使用均值和标准差对张量图像(不支持PIL图像)进行归一化,对n个通道都进行归一化,那么output[channel]?=?(input[channel]?-?mean[channel])?/?std[channel]

关于计算图像的均值和方差,这篇文章写的不错,如果想要了解,请看这篇文章,推荐:《PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差》,这里摘录一些重点:

?

?其他的一些预处理的操作可以参考官方文档以及这篇博客:《pytorch使用——(五)transforms详解》

3、torch.utils.data.DataLoader和torch.utils.data.Dataset

我们一点一点来分析……torch.utils.data这个Python API

Pytorch中官方文档对torch.utils.data的解释:torch.utils.data — PyTorch 1.9.0 documentation

在官方文档中,对于这个API的第一句话就提到了torch.utils.data.DataLoader,说它是Pytorch的核心数据加载工具,不仅提供数据集可以进行python迭代,同时也支持(这些功能在官方文档中有详细的说明,大致的意思是数据集加载可自定义修改且方便迭代,同时可以成批加载单类或多分类的数据,且有内存记忆的功能)

具体的含义可以通过?DataLoader的参数体现出来

  • torch.utils.data.DataLoader

????????主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要

  • dataset?(Dataset) – 数据读取接口(比如torchvision.datasets.ImageFolde),该输出是torch.utils.data.Dataset类的对象(或者继承该类的自定义的对象)

  • batch_size?(int,?optional) – 批训练数据量的大小,一般为2的指数:32/64/128……(default:?1).

  • shuffle?(bool,?optional) –打乱数据集,一般在训练数据时采用?(default:?False).

  • sampler?(Sampler?or?Iterable,?optional) – 定义从数据集中提取样本的策略. Can be any?Iterable?with?__len__?implemented. If specified,?shuffle?must not be specified.

  • batch_sampler?(Sampler?or?Iterable,?optional) – like?sampler, but returns a batch of indices at a time. 一般同batch_size,?shuffle,?sampler, and?drop_last等参数互斥,默认即可

  • num_workers?(int,?optional) – 加载数据集所用的进程数,可理解为参与数据提取的CPU核数,加快数据提取的速度?(default:?0)

  • collate_fn?(callable,?optional) –?合并样本清单以形成小批量。用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。

  • pin_memory?(bool,?optional) – If?True,数据加载器将把张量复制到CUDA内存中,然后返回它们。也就是一个数据拷贝的问题。

  • drop_last?(bool,?optional) – 如果数据集大小不能被批大小整除,则设置为True,以删除最后一个不完整的批。如果数据集的大小不能被批处理的大小整除,那么最后的批处理会更小。(default:?False)

  • timeout?(numeric,?optional) – 设置一个正数表示数据读取超时 (default:?0)

  • worker_init_fn?(callable,?optional) – If not?None, this will be called on each worker subprocess with the worker id (an int in?[0,?num_workers?-?1]) as input, after seeding and before data loading. (default:?None)

  • generator?(torch.Generator,?optional) – If not?None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate?base_seed?for workers. (default:?None)

  • prefetch_factor?(int,?optional,?keyword-only arg) – Number of samples loaded in advance by each worker.?2?means there will be a total of 2 * num_workers samples prefetched across all workers. (default:?2)

  • persistent_workers?(bool,?optional) – If?True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers?Dataset?instances alive. (default:?False)

如果想要对DataLoader(其实就是实现DataLoader的功能)了解更多,推荐文章:《PyTorch—torch.utils.data.DataLoader 数据加载类》

  • torch.utils.data.Dataset

? ? ? ? 提供数据集的一个抽象的类,当我们需要用到自定义的数据集时,可以去继承Dataset类并覆盖__len__()和__getitem__()方法,其中__len__()返回数据集的样本个数,getitem(index)返回训练集的第index个样本

二、加载自定义数据集

根据官网的说明,Base classes for custom datasets(自定义数据集):torchvision.datasets — Torchvision 0.10.0 documentation (pytorch.org)

一共有两类:

  • torchvision.datasets.DatasetFolder
  • torchvision.datasets.ImageFolder

1、torchvision.datasets.DatasetFolder

?DatasetFolder具体的参数是:

  • root?(string) – 根目录路径,比如root目录下包含cat和dog两个文件夹(两个类)

  • loader?(callable) – 一种函数,可以由给定的路径加载图片

  • extensions?(tuple[string]) – 允许的扩展列表。扩展名和有效文件都不应该被传递

  • transform?(callable,?optional) – 一种函数或转换,前面提到过?transform,E.g,?transforms.RandomCrop?for images.

  • target_transform?(callable,?optional) – 同前

  • is_valid_file?–检查文件是否有误

?

??DatasetFolder这是一个通用的数据加载器,通过重写find_classes()这个方法可以获得文件的目录结构(这个就是说在py文件里面你可以重写这个方法,这样可以灵活的应用数据集的文件)

??

但是一般文件目录需要遵守下面的结构框架:

?

?这里的dictionary指的是包含分类(dog/cat/caw……)文件夹的根目录文件(如上图)相当于前面的self.root,如果找不到相应的目录文件,就会返回错误FileNotFoundError,如果准确无误,那么返回:一个元组? 各种分类的列表,以及对应各种分类对应索引的字典

(Tuple[List[str], Dict[str,?int]])

假设我的目录结构为:

?find_classes(root)源码如下:

该函数的参数:dir = root,调用DatasetFolder类时使用的目录,多为travel,test等

def find_classes(dir):
#这里的dir则表示为".../root"
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]   
    # 遍历dir目录下的所有子目录名称(dog和cat)并将其存在classes中
    # classes = [dog,cat]
    classes.sort()
    # 由于Python版本的不同可能需要更换为sorted(classes)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    # class_to_idx就是将类别转化为数字表示,class_to_idx = {dog:0;cat:1}
    # 创建一个字典,将类别与数字对应
    return classes, class_to_idx
    # classes = [dog,cat]
    # class_to_idx = {dog:0;cat:1}

可以看到只是在读取文件的路径和形式有些许区别,在使用不同的数据集时需要我们自己去调整具体细节,最后可以输出classes, class_to_idx这样的形式即可<(^-^)>

2、torchvision.datasets.ImageFolder

实际上是继承DatasetFolder的,他们的方法都是一样的,同样的方法定义数据集

返回的dataset都有以下三种属性:

  • self.classes:用一个 list 保存类别名称
  • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  • self.imgs:保存(img-path, class) tuple的 list

?推荐博客:

《pytorch学习笔记七:torchvision.datasets.ImageFolder使用详解》

程序案例:

from torchvision.datasets import ImageFolder
from torchvision import transforms

#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
])

dataset=ImageFolder('./data/train',transform=transform)

得到的dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…],下面我们打印第一个元素:【(图像tensor数据,标签),(图像tensor数据,标签),……】

?再看一下dataset的三个属性:

三、加载本地kaggle猫狗数据集

文件目录

└─kaggle
? ? ├─test
? ? └─train

在train文件夹下:

在test文件夹下:

?仔细看小猫小狗,可以发现它们姿态不一,有的站着,有的眯着眼睛,有的甚至和其他可识别物体比如桶、人混在一起。同时,小猫们的图片尺寸也不一致,有的是竖放的长方形,有的是横放的长方形,但我们最终需要是合理尺寸的图片。所以需要进行图片处理,并把图片转化成Tensor作为模型的输入。(参考推荐博客:《pytorch实现kaggle猫狗识别(超详细)》

data_transform = transforms.Compose([
    transforms.Resize(84),
    transforms.CenterCrop(84),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root='E:/AI学习/数据集/图像分类数据集/kaggle猫狗数据集/kaggle/train/', transform=data_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=num_workers)

在下一篇博客中,我们会记录:

(!详解 ?Pytorch实战:②)kaggle猫狗数据集二分类:……

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-25 12:12:14  更:2021-08-25 12:13:37 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 21:43:45-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码