IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PyTorch Dataloader源码分析(一) -> 正文阅读

[人工智能]PyTorch Dataloader源码分析(一)

torch.utils.data.Dataloader是PyTorch数据加载工具的核心类。在网络脚本中使用流程一般如下:

train_loader = torch.utils.data.Dataloader(...)
for input, target in train_loader:
    # 前向计算
    output = model(input)
    # 计算损失
    loss = loss_fn(output, target)
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    # 梯度更新
    optimizer.step()

从使用方式可以看出,Dataloder本质上是将数据抽象成可迭代的python对象使用,除此之外还支持:

  • map风格和iterable风格的数据集
  • 自定义数据加载顺序
  • 自动批处理
  • 单/多进程数据加载
  • 自动内存锁页

上述功能选项在Dataloder构造参数中配置:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

下面分为几个章节分别探究Dataloder的代码结构和上述功能的具体实现。

一、Dataloader整体架构

根据上述Dataloader的使用方式来看,Dataloader应该是一个iterrable(可迭代对象),内部需要维护一个iterator(迭代器)。python中for循环访问可迭代对象的内部流程如下:

  1. 调用可迭代对象的__iter__方法,拿到迭代器对象
  2. 调用迭代器的__next__方法,遍历内部数据
  3. 循环步骤2,直到迭代器内部数据流访问完成,捕获异常

对应下图,for循环先通过Dataloader类的__iter__方法拿到迭代器,即_SingleProcessDataLoaderIter类或_MultiProcessDataLoaderIter类,之后每次循环都会调用迭代器的__next__方法获取input和target数据(具体如何获得数据后面会介绍),直到全部数据访问完退出for循环。
Dataloader
Dataloader类的工作比较简单,对用户参数做个检查,再创建出xxxDataLoaderIter迭代器需要的一些类(例如Sampler和BatchSampler)以及xxxDataLoaderIter迭代器,剩下对数据集的访问工作就全权委托xxxDataLoaderIter迭代器去做了。
在具体介绍xxxDataLoaderIter迭代器之前,先简单了解下其依赖的组件及大致工作流程。
xxxDataLoaderIter主要用到的组件有:

  • Dataset类
  • Sampler && BatchSampler类
  • Fetcher类
  • collate_fn函数
  • pin_memory函数

每次for循环调用next(xxxDataLoaderIter),都由上述组件相互配合完成对具体数据的访问。大致流程如下:
xxxDataLoaderIter
BatchSampler类每次产生一堆index下标,Fetcher类通过index下标将数据从Dataset类中取出来,然后通过collate_fn函数将取出来的数据整理成Tensor,如果开启了pin memory功能,还会将对应的pageble Tensor转换成pinned Tensor然后输出。
以上就是xxxDataLoaderIter迭代器的大致工作流程,至于每个组件的功能实现以及具体的_SingleProcessDataLoaderIter迭代器和_MultiProcessDataLoaderIter迭代器的工作原理,请看后续章节分析。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-28 15:28:52  更:2022-02-28 15:29:02 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 3:26:15-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码