IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【pytorch】|Dataloader dataset sampler torchvision -> 正文阅读

[人工智能]【pytorch】|Dataloader dataset sampler torchvision

pytorch 数据读取机制

PyTorch中对于数据集的处理有三个非常重要的类:Dataset、Dataloader、Sampler,它们均是 torch.utils.data 包下的模块(类)。

torch/utils/data下面一共含有4个主文件

|---- dataloader.py
|---- dataset.py
|---- distributed.py
|---- sample.py

pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,通过dataset、sampler参数将img, label加载到模型中进行训练

在这里插入图片描述

dataloader \dataset \sampler概述

dataloader \dataset \sampler三者关系

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • Dataset是数据集的类,主要用于定义数据集
  • Sampler是采样器的类,用于定义从数据集中选出数据的规则,比如是随机取数据还是按照顺序取等等
  • Dataloader是数据的加载类,Dataset和Sampler会作为参数传递给Dataloader,是对于Dataset和Sampler的进一步包装,用于实际读取数据,可以理解为它是这个工作的真正实践者,而Dataset和Sampler则负责定义。我们训练、测试所获得的数据也是Dataloader直接给我们的。

torch.utils.data Dataset

我们通过定义继承自这个类的子类来自定义数据集。它有两个最重要的方法需要重写,实际上它们都是类的特殊方法:

getitem(self, index):传入参数index为下标,返回数据集中对应下标的数据组(数据和标签)
len(self):返回数据集的大小
Dataset的目标是根据你输入的索引输出对应的image和label,而且这个功能是要在__getitem__()函数中完成的,所以当你自定义数据集的时候,首先要继承Dataset类,还要复写__getitem__()函数。

torch.utils.data DataLoader

Dataloader对Dataset(和Sampler等)打包,完成最后对数据的读取的执行工作,一般不需要自己定义或者重写一个Dataloader的类(或子类),直接使用即可,通过传入参数定制Dataloader,定制化的功能应该在Dataset(和Sampler等)中完成了。

from torch.utils.data import DataLoader

 def __init__(self, dataset, 
 			batch_size=1, 
 			shuffle=False, 
 			sampler=None,
            batch_sampler=None, 
            num_workers=0, 
            collate_fn=None,
            pin_memory=False, 
            drop_last=False, 
            timeout=0,
            worker_init_fn=None, 
            multiprocessing_context=None):

from torch.utils.data import DataLoader

DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…)
该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

dataset(Dataset): 传入的数据集

  • batch_size(int, optional): 每个batch有多少个样本

  • shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

  • sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

  • batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

  • num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

  • pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

  • collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

  • drop_last (bool, optional): 这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,如果设置为True,那么训练的时候后面的36个就被扔掉了…
    如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

  • timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

  • worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each
    worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

torch.utils.data.sampler

from torch.utils.data.sampler import Sampler

Sampler类是一个很抽象的父类,其主要用于设置从一个序列中返回样本的规则,即采样的规则。Sampler是一个可迭代对象,使用step方法可以返回下一个迭代后的结果,因此其主要的类方法就是 iter 方法,定义了迭代后返回的内容。

但是,Dataloader中的sampler和batch_sampler参数默认情况下使用的那些采样器(RandomSampler、SequentialSampler和BatchSampler)一样,PyTorch自己实现了很多Sampler的子类,这些采样器其实可以完成大部分功能,所以基本只需要关注一些Sampler的子类以及他们的用法。

SequentialSampler

SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样。

from torch.utils.data import SequentialSampler

pseudo_dataset = list(range(10))
for data in SequentialSampler(pseudo_dataset):
    print(data, end=" ")

0 1 2 3 4 5 6 7 8 9 

RandmSampler

RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:

  • replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
  • num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导致重复数据被采样,导致有些数据被采不到,所以往往会设置一个比较大的值.
from torch.utils.data import RandomSampler

pseudo_dataset = list(range(10))

randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)

print("for random sampler #1: ")
for data in randomSampler1:
    print(data, end=" ")

print("\n\nfor random sampler #2: ")
for data in randomSampler2:
    print(data, end=" ")

for random sampler #1: 
4 5 2 9 3 0 6 8 7 1 

for random sampler #2: 
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6 

SubsetRandomSampler

SubsetRandomSampler可以设置子集的随机采样,多用于将数据集分成多个集合,比如训练集和验证集的时候使用:

from torch.utils.data import SubsetRandomSampler

pseudo_dataset = list(range(10))

subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])

print("for subset random sampler #1: ")
for data in subRandomSampler1:
    print(data, end=" ")

print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
    print(data, end=" ")

for subset random sampler #1: 
0 4 6 5 3 2 1 

for subset random sampler #2: 
7 8 9 

BatchSampler

以上的四个Sampler在每次迭代都只返回一个索引,而BatchSampler的作用是对上述这类返回一个索引的采样器进行包装,按照设定的batch size返回一组索引,因其他的参数和上述的有些不同:

  • sampler:一个Sampler对象(或者一个可迭代对象)
  • batch_size:batch的大小
  • drop_last:是否丢弃最后一个可能不足batch size大小的数据
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10))

batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)

print("for batch sampler #1: ")
for data in batchSampler1:
    print(data, end=" ")

print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
    print(data, end=" ")

for batch sampler #1: 
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9] 

for batch sampler #2: 
[0, 1, 2] [3, 4, 5] [6, 7, 8] 

torchvision

torchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程,也会让你安装上这个包。

这个包中有四个大类。

torchvision.datasets

torchvision.models

torchvision.transforms

torchvision.utils

torchvision.datasets

import torchvision.datasets as datasets

torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。


__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
           'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
           'Places365')
import torchvision

trainset= torchvision.dataset.MNIST(root='./data'
															train=True,
															transform=None)
train_Loader=Dataloader(dataset=trainset,
										batch_size=32,
										shuffle=False)
print("训练集大小",len(trainset))
print("训练集批次",len(train_loader))
															

torchvision.models

torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

torchvision.models模块的 子模块中包含以下模型结构。

from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from . import segmentation
from . import detection
from . import video
from . import quantization

快速创建一个权重随机初始化的模型

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

也可以通过使用 pretrained=True 来加载一个别人预训练好的模型

import torchvision.models as models
resnet18=models.resnet18(pretrained = True)
print(resnet18)

torchvision.transforms

transforms 模块提供了一般的图像转换操作类。

class torchvision.transforms.ToTensor 

把shape=(H x W x C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的torch.FloatTensor。

class torchvision.transforms.Normalize(mean, std) 

给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。

# 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])

DOWNLOAD = True
BATCH_SIZE = 32

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

print(len(train_dataset))
print(len(train_loader))

ref
https://github.com/LianHaiMiao/pytorch-lesson-zh/
https://zhuanlan.zhihu.com/p/91521705
https://blog.csdn.net/g11d111/article/details/81504637
https://zhuanlan.zhihu.com/p/400830261

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-22 20:35:19  更:2022-03-22 20:40:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 1:53:31-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码