IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 《Pytorch学习指南》- Dataset和Dataloader用法详解 -> 正文阅读

[人工智能]《Pytorch学习指南》- Dataset和Dataloader用法详解

前言

本章节主要介绍如何使用torch.utils.data 中的Dataset和Dataloader来构建数据集, 重点要看使用细节

DataSet

  • torch.utils.data.Dataset
    • 功能 : Dataset抽象;类, 所有自定义的Dataset都需要继承他, 并重写相应的方法
    • getitem(self, index)
      1. 接收一个索引, 返回一个样本 : index => label, data
      2. 返回的样本的大小要一样

DataLoader

  • torch.utils.data.DataLoader
    • 功能 : 创建可以迭代的数据装载器
    • 参数 :
      1. dataset : Dataset类对象, 决定数据从哪读取以及如何读取
      2. batchsize: 决定数据批次大小
      3. num_works: 多进程读取数据的线程数
      4. shuffle: 每个 epoch 是否乱序
      5. 当样本数不能被batchsize整除时, 是否舍去最后一个batch的数据
    • 名词解释 :
      1. 样本总数 : 80, batchsize : 8 => 1 Epoch = 10 iteration

数据构建

1. 创建Dataset 类 ?

class WeiBoDataset(Dataset):
	pass

2. 读取数据 🚑

注意 : 我们一般会在初始化的时候就加载进数据, 读取数据函数需要自定义

class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)

3. 返回数据 ??

  • 这里需要注意的是, len 是必须要设置的, 返回的是你数据集的大小
  • 根据返回的len来构建索引, 然后把构建好的索引传入__getitem__里
  • getitem 根据传进来的索引获取对应的数据, 可以在这个方法里对数据进行处理
class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)

    def __len__(self):
        """
            这个必须要设置, getitem中的index就是根据这个来设置的
        :return:
        """
        return len(self.data)

    def __getitem__(self, index):
        label = 1
        # features = [str(i) for i in range(10)]
        features = np.array([i for i in range(10)])
        return label, features

读取数据 🎨

weibo_dataset=WeiBoDataset("../../datasets/weibo_test_data.csv)
dataloader=DataLoader(weibo_dataset,batch_size=1024,shuffle=True)
for i, batch in enumerate(dataloader):
	# batch : [label, features] 组成
    print(type(batch[0]), type(batch[1]))

注意细节 🚀

  1. 先获取数据集的大小 len
  2. 根据len生成index, 然后shuffle
  3. 根据shuffle后的数据以及batch_size生成索引列表batch_index, 索引列表的大小为 batch_size
  4. 获取每个batch的数据时, 根据batch_index传入到 getitem 获取对应的数据
  5. 注意 : batch的数据类型取决于__getitem__返回的类型, 一般都会转换为tensor
  6. 有的数据类型是无法转换为tensor的, 比如 元素类型为str的list
  7. default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists found
  8. 上面报错原因就是 因为数据无法转换为 tensor , 而类型又不属于 tensors, numpy arrays, numbers, dicts or lists 这几种
  9. 如果返回的数据是集合类型, 可以直接使用 np.array() 转换为ndarray类型, 这样会被自动转换为tensor, 当然要求这个集合类型的元素类型是tensor有的
  10. 如果是tensor没有的,比如 str 类型的, 反而会报错, 比如 7. 报错

对比实验

注意 features的元素类型是str, 那么可以看到下面的输出结果中 label 是 tensor, features 是 list类型的

def __getitem__(self, index):
	label = 1
	# 转换为 ndarray 会报错
	# features = np.array([str(i) for i in range(10)]) 
    features = [str(i) for i in range(10)]
    return label, features
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>

下面将feature中的数据元素换成了int类型的, 并且对将list转换为ndarray, 这样在获取batch时数据会自动转换为tensor , 但是这里需要注意的是, 上面的数据是不能用np.array()的, 这是因为 batch 必须包含 tensors, numpy arrays, numbers, dicts or lists 这几种类型, 其他的都会报错, 具体可以查看

def __getitem__(self, index):
	label = 1
    features = np.array([i for i in range(10)])
    return label, features
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-28 12:23:44  更:2021-10-28 12:25:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 8:47:46-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码