IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch加载自带数据集以及个人数据集的方式 -> 正文阅读

[人工智能]pytorch加载自带数据集以及个人数据集的方式

一、加载pytorch自带数据集

torchvison.datasets是torch.utils.data.Dataset的实现。
包括如下数据集:
all = (‘LSUN’, ‘LSUNClass’,
‘ImageFolder’, ‘DatasetFolder’, ‘FakeData’,
‘CocoCaptions’, ‘CocoDetection’,
‘CIFAR10’, ‘CIFAR100’, ‘EMNIST’, ‘FashionMNIST’, ‘QMNIST’,
‘MNIST’, ‘KMNIST’, ‘STL10’, ‘SVHN’, ‘PhotoTour’, ‘SEMEION’,
‘Omniglot’, ‘SBU’, ‘Flickr8k’, ‘Flickr30k’,
‘VOCSegmentation’, ‘VOCDetection’, ‘Cityscapes’, ‘ImageNet’,
‘Caltech101’, ‘Caltech256’, ‘CelebA’, ‘SBDataset’, ‘VisionDataset’,
‘USPS’, ‘Kinetics400’, ‘HMDB51’, ‘UCF101’, ‘Places365’)

1.使用torchvision.datasets加载数据集

import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)

2.使用torch.utils.data.DataLoader来实例化

cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)

3.测试

for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break

二、加载个人的数据集

1.继承Dataset类,生成数据集

import torch.utils.data as data
#定义myDataSet类来继承Dataset

#generate train_data or test_data...
def default_loader(path):
    return  Image.open(path).convert('RGB')

class myDataSet(data.Dataset):
    """"
    @:param
    label_txt:每个图像名称以及路径,one image one line
    """
    def __init__(self,label_txt,transform = None,target_transform = None, loader=default_loader):
        super(myDataSet, self).__init__()
        self.imgs = []
        self.transform =transform
        self.target_transform = target_transform
        self.loader =loader
        fn = open(label_txt,'r')
        imgs=[]
        for line in fn:
            line  = line.strip('\n')
            line = line.rstrip('\n')
            words = line.split()
            imgs.append(words[0])
        self.imgs = imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        fn = self.img[index]
        img = self.loader(os.path.join(self.root,fn))
        return  img

label_txt的格式如下:
每一行是一个图像的绝对路径
同时,需要重写__len__与__getitem__两个函数如上
在这里插入图片描述

2.加载数据集

def get_my_data():
    train_data = myDataSet(label_txt='',transforms=transform.ToTensor())
    test_data = myDataSet(label_txt='', transforms=transform.ToTensor())
    train_loader = DataLoader(train_data,shuffle=True,batch_size=BATCH_SIZE,num_workers=1)
    #test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE, num_workers=1)
    return train_loader

参考文献:
https://blog.csdn.net/sinat_42239797/article/details/90641659
https://zhuanlan.zhihu.com/p/27434001

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-16 22:21:38  更:2022-03-16 22:22:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 15:41:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码