IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch dataloader 和 dataset 数据加载的研究 -> 正文阅读

[人工智能]pytorch dataloader 和 dataset 数据加载的研究

一 pytorch 数据加载的研究

一、dataloader and dataset?

Dataset抽象类,所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。

DataLoader(): 迭代器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

二、类的实例化

大多数文章,并没有仔细探究Dataset这个类,究竟是怎么一步步完成数据和标签的加载的

first of all ,它是个类
所以,从类的角度,继承,重写,实例化,这个面向对象的思路,先研究一下

1.继承Dataset

代码如下(示例):xxx代表可以自己定义的内容

class myDataset(Dataset):
    def __init__(self, xxx):
     
    def __getitem__(self,index):
            return xxx,xxx
            
    def __len__(self):
        return len(xxx)

可见,getitem 和 len 需要自己重写,并返回一些东西

2.重写父类函数

这里,采用了kaggle 的Dog Breed Identification项目的数据
是个分类任务,使用resnet vgg 就可以解决
数据集包含 3个文件
train(文件夹)
test(文件夹)
label.csv
在这里插入图片描述
可以到官网看 https://www.kaggle.com/competitions/dog-breed-identification/
代码如下(示例):

from torch.utils.data import Dataset
import pandas as pd
import cv2
class myDataset(Dataset):
    def __init__(self, dogdir):
        self.imgset =  dogdir["id"]
        self.labelset = dogdir["breed"]
        dog_breeds = sorted(list(set(self.labelset )))
        n_classes = len(dog_breeds)
        self.class_to_num = dict(zip(dog_breeds, range(n_classes)))
        
    def __getitem__(self,index):
            imgpath = "train/"+self.imgset[index] + ".jpg"
            img = cv2.imread(imgpath)
            labelname  = self.labelset[index]
            labelhot =  self.class_to_num.get(labelname)
            return img, labelhot
    
    def __len__(self):
        return len(self.imgset)

3.实例化

看看继承后的dataset

df = pd.read_csv('labels.csv')  #使用pandas读取csv  
myd = myDataset(df)
img,label = myd.__getitem__(4) #指定4这个item
lenth = myd.__len__()
#print(img)
print(label)
print(lenth)

49 #one hot 编码后的标签
10222 # 总体数量

总结

提示:这里对文章进行总结:

例如:以上就是今天要讲的内容

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-24 09:26:48  更:2022-04-24 09:30:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 9:31:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码