IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 关于pytorch读取数据集的一些知识点 -> 正文阅读

[人工智能]关于pytorch读取数据集的一些知识点

torchvision包提供了一些常用的数据集和转换函数,使用torchvision甚至不需要自己写处理函数。

一、对于torchvision提供的数据集

对于这一类数据集,PyTorch已经帮我们做好了所有的事情,连数据源都不需要自己下载。
Imagenet,CIFAR10,MNIST等等PyTorch都提供了数据加载的功能,所以可以先看看你要用的数据集是不是这种情况。

比如,加载MNIST数据集:

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TEST_BATCH_SIZE, shuffle=False)

二、对于特定结构的数据集

通过torchvision中的通用数据集ImageFolder来完成加载,它假设数据结构为如下:

root/airport/airplane(1).jpg
root/airport/airplane(2).jpg
root/airport/airplane(3).jpg
.
.
.
root/beach/beach(1).jpg
root/beach/beach(2).jpg
root/beach/beach(3)_.jpg

同样

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 数据增强
train_transforms = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # 随机裁剪到256*256
        transforms.RandomRotation(degrees=45),  # 随机旋转
        transforms.RandomHorizontalFlip(),      # 随机水平翻转
        transforms.CenterCrop(size=224),        # 中心裁剪到224*224
        transforms.ToTensor(),                  # 转化成张量
        transforms.Normalize([0.485, 0.456, 0.406],  # 归一化
                             [0.229, 0.224, 0.225])
])

test_valid_transforms = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406],
                              [0.229, 0.224, 0.225])
])

# 利用Dataloader加载数据
train_directory = config.TRAIN_DATASET_DIR
valid_directory = config.VALID_DATASET_DIR

batch_size = config.BATCH_SIZE
num_classes = config.NUM_CLASSES

train_datasets = datasets.ImageFolder(train_directory, transform=train_transforms)
train_data_size = len(train_datasets)
train_data = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)  # shuffle将序列的所有元素随机排序

三、对于普通数据集

定义数据集的类MyDataset,这个类要继承Dataset这个抽象类,然后重写下面的函数:

①__len__: 使得len(dataset)返回数据集的大小;

②__getitem__:使得支持dataset[i]能够返回第i个数据样本的下标操作

通常情况还包括初始函数__init__.

# 读取图片,主要是通过Dataset类
# 通过继承torch.utils.data.Dataset的这个抽象类,可以定义我们需要的数据类
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, img_paths, labels, transform=None):
        self.img_paths = img_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):  # 实现索引数据集中的某一个数据
        img_path, label = self.img_paths[index], self.labels[index]

        img = cv2.imread(img_path)                  # 接口读取图片,读进来是BGR格式数据
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 色彩空间转化函数cv2.cvtColor()进行色彩空间的转换,将BGR格式转换成RGB格式
        img = Image.fromarray(img)                  # numpy中的数组array转换成PIL中的image

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):               # 返回数据集的大小
        return len(self.img_paths)

train_set = MyDataset(
	train_img_paths,
	train_labels,
	transform=train_transform)

train_loader = torch.utils.data.DataLoader(
	train_set,
	batch_size=args.batch_size,
	shuffle=True,
	num_workers=4,
	sampler=None)

参考:链接: https://www.jianshu.com/p/6e22d21c84be.

待更新…

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-10 14:32:45  更:2021-07-10 14:33:49 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 14:39:53-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码