| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 数据过大时dataloader怎么设计? -> 正文阅读 |
|
[人工智能]数据过大时dataloader怎么设计? |
前言:最近笔者在跑项目的时候遇到一个场景:训练数据过大比如100G,那么是不可能全部加载到内存后训练的,那怎么办呢? 其实具体来说当数据过大时,其会导致两个问题 (a)加载数据时间过长 (b)时间长也就忍了,关键是内存会爆掉 怎么解决呢?可能想到的是边加载数据边训练,那如果我们的代码恰好又是多机多卡的呢?那我们的dataloader该怎么实现呢? 下面会涉及到一系列小碎点,我们逐步深入的看看怎么实现。 本篇涉及的知识较多,比较绕,建议收藏反复琢磨。一个基本能跑通的dataloader在文末已给出~ 加载数据快起来本节需要的是python的file.seek()和file.tell(),不熟悉的小伙伴,建议先查查相关资料,很简单~ 一般来说,我们在写dataloader的时候,是会把数据全部加载进来一遍的进而导致崩掉,这里通常采用的一个trick手段是:file.seek()和file.tell() file.seek()是会将文件指针移动到指定的地方,那么假设我们的训练数据train.txt有10000行,那么我们在写dataloader的时候就不需要将数据全部加载进来,而是通过如下:
注意看open("train.txt")一行基本上很快,而且不用占用很大内存,是通过seek不断移动到对应行取对应的数据,细心的读者可能已经发现这里有一个offsets列表,其实是因为f.seek是不能直接按行移动的,其只能按照字符移动,但是每一行的字符数又不一样,所以offsets列表就是记录每行之前字符数的,比如[45, 67, 89],那么以后就直接可以根据idx取出对应字符数,进而seek根据其可以直接定位到行的启始位置。 那么offsets(dataset_idx_path)怎么得到呢? 很简单
这里的data_file就是你的100G的大文件,"dataset_tmp.id"就是产生的dataset_idx_path,供后续加载。 说到这里,需要注意两点 (1)数据能离线预处理的就离线预处理,比如如果使用的是bert等预训练模型,那么最好离线处理成tokenizer id即data_file中的数据是模型直接可以使用的,而且上述"dataset_tmp.id"其实也是预先离线得到好的文件。 (2)线上dataloader加载数据训练的时候其实只需要花一点时间加载"dataset_tmp.id"和open(data_file),后者基本不花时间。而且全程没有加载很大的数据到内存,因为其是不断通过文件指针找对应数据的。 通过以上方式,我们就可以基本避免因为全部加载数据而奔溃掉的问题。是不是有点时间(文件指针不断找)换空间(内存不爆掉)的味道~~ 所以总的来说这块解决了内存爆掉已经加载过长的问题。 GPU利用率低本节需要的是python的multiprocessing包,不熟悉的小伙伴,建议先看看相关的教程,也不是很难,熟悉最基本的API即可,比如熟悉Process,Manager,Lock,Pool概念。 如果我们的数据很大的时候,使用上述方式,虽然可以基本跑起来了,但是会发现GPU的利用率还是不够高,或者说波动很大,那是为什么呢? 其实是GPU在等待数据从CPU传输过来,GPU处理速度很多,马上结束后,会等待下一个batch的传入。之所以要等是因为文件指针要花时间不断的找对应数据,这是需要换时间的。 所以我们为了空间,牺牲掉了时间,但是时间又影响了GPU利用率。 那怎么解决呢?多进程!!! 既然找的慢,那我们多开几个进程来找,这样不就快起来了吗?注意不是使用的多线程哦,多线程本质上还是时间切片,不能真真意义上的加速。 另外一个问题是怎么实现一边装数据一边训练呢?那就是创建一个缓冲buffer。一个进程是不断往buffer装数据,另外一个是不断取数据去训练,两则互不影响。也就是说你在训练的时候,我还在不停的往buffer里面装数据以至于你想用的时候可以直接取就行。 下面我们再具体来理一下这里的实现逻辑: 因为要用多线程,所以我们这里可以申请两个全局buff列表即multiprocessing.manager.list(),这里简单起名为 buff_single和buff_batch,下面我们来定义一些变量,比较复杂,大家可以花点时间理一下, 装数据的进程不断的seek到一个个样本数据后装到buff_single列表。因为训练的数据进程取数据的时候希望是一个batch一个batch的取,所以我们还需要将buffsingle列表的数据组装成一个一个batch形式放到buffer_batch列表中,这个函数就暂时叫做fill_batch吧。 假设我们buff_single的大小设置为25600即最多装25600个样本。同时我们在装数据的时候(本质上是上述第一节seek的逻辑)希望又是一个多进程,这里假设用5个进程吧即num_process=5。 这里简单提示一下:看到这里很多人可能有个疑问,装数据和取数据本身就是两个进程?那装数据又是多个进程?一边装一边取的关键trick是怎么实现的呢?下面我们来看: (1)我们首先来看取数据的逻辑
通常来说我们声明了一个dataloader后,采用迭代器来不断的取即
这里构造出来的train_dataloader迭代器,其实就是_iter_函数。 所以我们重点来看DataLoader这个类,可以看到_iter_函数是一个while True的逻辑,也就是说不断循环,而且不会停,即就是不断从self.buffer_batch中取batch,即使self.buffer_batch没有数据了,其还是会不断在这里循环等,直到有了数据后yield返回。这就是取数据的逻辑。很简单吧 那装数据的逻辑在那里呢?那就是_fill_buf这个函数,可以看到其实在第一次从train_dataloader这个迭代器中取数据的时候,也即第一次调用这个_iter_函数的时候,其实是先会进行_fill_buf这个函数的,它就是会开启多个进程不断的向self.buffer_batch装数据,特别需要注意,multiprocessing开启进程后,就是不断的自己去运行了,程序会接着往下走,也就是说_fill_buf相当于开启了后台进程,主程序继续往下走即while True这里。 上面这段话主要就是想解释怎么实现的一边装数据、一边取数据。大家可以多理解理解,说简单其实也很简单,就是_fill_buf在后台开启几个进程,往self.buffer_batch里面不断装数据。外部取数据就是不断从self.buffer_batch取,没有的话就是不停的while True进行循环直到self.buffer_batch里面装了数据就行,不需要多,只有有一个数据装进来就可以啦。 (2)接着我们来看怎么实现多进程装数据
可以看到我们这里就是很简单的开启了多个进程,即开启了self.num_process个进程运行self.buf_thread程序,然后就是start,其就在后台不断的运行啦,为什么说不断?是的就是不断!self.buf_thread函数是不会停的,其就是不断的装数据,一遍又一遍的循环数据,即达到一个epoch后,再重新进行取新一轮的epoch。 这里再理一下这里的逻辑: 大家可以看到不论是取数据还是装数据,其实内部都是一个while True的逻辑即不断循环,进而实现不断的取不断的装的逻辑,这里弱化了epoch的概念,是不断的取数据流,重复一遍一遍的取,那程序整个结束的标志在哪里呢?其实是在外面的,即在训练流程(看上面):
所以看的是全程是step来控制,没有epoch,当然了有了step自己推一下epoch是很轻松的事。 好啦,言归正传,现在回到正题,目前最重要的应该是self.buf_thread这个函数了,其的作用就是不断的往self.buffer_batch里面装batch数据,当然了self.buffer_batch是一个全局变量,因为我们是开启了多个进程同时都往self.buffer_batch里面装数据的。下面我们看看buf_thread的实现
首先可以看到有两个while True,最外面那个确保的是不断的一遍一遍的取整个数据集。最里面那个就是一个个取装到buffer,首先我们取一个个的是先装到self.buffer_single,当装慢时(>=self.instances_buffer_size),就开始组batch,即调用的是_fill_batch()这个函数(上面说过),说白了起就是将self.buffer_single里面的单个样本组合成batch的形式放到另一个buffer中即self.buffer_batch供程序取数据进行训练。 需要关注的是:
这里本质上是为了确保不同进程装不同的数据,进而避免装重复了。
同时需要注意一点的是文件的读取等都要放到buf_thread子函数中,比如不能把 f_read
作为一个类变量,因为多进程是可以抢的,假设当成一个类全局变量,那么可能遇到这种情况:进程1刚将 f_read seek到自己想要的位置,下一步马上就要f_read.readline()了,结果另外一个进程抢先又seek了一下,所以最后结果就是乱的,所以每个进程都自己open一个数据流f_read避免乱。 下面我们再看组bacth的函数_fill_batch
可以看到_fill_batch其实就是随机取self.buffer_single这个buffer中的一批数据来做batch放到self.buffer_batch中,并且把对应的用掉的数据从self.buffer_single中pop出来,小插曲:为啥这里要用sort呢?其实是由于python的pop函数导致的,我们是按index从大到小pop,这样就不会乱,不然pop的过程中,self.buffer_single的大小是变化的,不从大到小确保不了pop出对应位置的数据。 注意这里使用了multiprocessing的Lock,因为所有进程都是可能进到_fill_batch这个函数进行操作的,那大家都可能会同时pop self.buffer_single这个数据流,那 self.buffer_single的长度都是变化的,进而导致都是乱的,你pop一下,另外一个进程冷不丁的pop一下,所以这里用了Lock即同一时间只能有一个进程进行这里的操作。 当然了对于往self.buffer_single和 self.buffer_batch装数据这一过程(append)都没使用Lock,那是因为不需要,大家同时合力装就可以啦,你一下我一下,顺序无所谓。 多机多卡本节需要的知识是多机多卡的基本概念,比如:local_rank、world_size、rank,不需要的小伙伴可以先查阅一下相关知识。 我们同时还想在单机多卡上面,甚至多机多卡上面跑,那这里又有什么变化呢? 试想使用了多卡后,对于dataloder我们本质要解决的问题是什么呢? 那就是不同卡应该加载不同的数据,之间不能有重复,说白了,如果我们可以把数据分段,假设现在一共8张卡,那我们把数据分成8段,只让对应卡加载对应的数据段上面的数据不就可以啦。 那怎么实现呢?非常非常简单 这里我们来说最复杂的场景:多机多卡 那在多机多卡的场景中每一张卡的全局ID可以使用rank,那每一张卡对应的自己数据片段的全部数据就应该是:
即在上面dataloder中补上这个逻辑就可以啦。 当然了pytorch也提供了torch.utils.data.distributed.DistributedSampler这个API,其会自动取对应自己数据片段的index,而不会去取别人的数据,本质上和上面的实现逻辑一样。如果懒得写,可以直接使用torch.utils.data.distributed.DistributedSampler。 那如果是单机多卡呢?没有rank和world_size,其只有一个local_rank参数。其实更好办了,我们只需要理解了其实我们需要两个参数,一个是当前卡的ID,另外一个是总卡数,那么就可以通过如下代码实现唯一切片数据
所以对应到单机多卡便可以这样写:
dev dataloader怎么设计?最后再说一个大家可能遇到的问题,就是在训练的时候,有可能需要在固定步数看一下验证集上的效果,那必然需要一个dev dataloader,可能有人会说,这有啥,就用上面的就行,可是这里会有一些问题,大家可以停几分钟想一下直接用上面的会有什么问题? ... 好,下面我们来说一下,不知道大家注意到没有,之前我们也说了,上面的dataloader是一直处于while True的逻辑的,不会停,循环装数据,而我们的dev dataloader是希望过一遍dev数据即可,所以适当的stop就尤为重要。 这时候可能有的小伙伴会说,计算出对应的最后一个数据stop就行啦,但是这里涉及到多进程,要等所有进程结束,所以这里又是一个比较复杂的逻辑实现,注意train dataloader其实通过弱化epoch,使用step不断取的过程简化了该步骤。 当然了这不是什么难事,大家看buf_thread函数最里面while True这个函数的break,其实就是一个epoch的结束。 这里将train dataloder和dev dataloader统一写出一个dataloder,为了简单,我们dev dataloader可以写成之前的最基本的dataloader形式,主要基于以下两个原因: (a)通常来说dev数据不大,不太需要多进程以及边加载边取数据 (b)在多卡训练场景,dev数据可以只在一张固定的卡上进行(比如卡0),所以也不会涉及到多机多卡一节中说到的数据切片,直接在一张卡上过所有数据即可。 总结看完后最好是将其逻辑应用到自己的当前dataloder,代码不一定要按上面写,可以结合自己的场景随意改动,甚至可以用到tensorflow中,所以理解其逻辑最重要,如果能跑通,恭喜你,说明理解了所有逻辑。 最后给一份笔者目前整理的一个能跑通的dataloder,欢迎star~,后续有时间会持续优化一下相关代码支持更多的逻辑比如train dataloder不想用边装边取,只想普通的等等。 彩蛋文章开头我们介绍了要尽可能的将我们的数据预处理好,那当我们的数据非常大的时候,势必要用到一些大数据工具比如spark,haddop等等,下面一篇我们会简单介绍一下pyspark基本dmeo,供大家快速上手~ 关注欢迎关注,下期再见啦~ 欢迎关注笔者微信公众号: github: https://github.com/Mryangkaitong?github.com 知乎: |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/10 16:12:40- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |