IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch加载自己本地数据集 深度学习数据预处理 -> 正文阅读

[人工智能]pytorch加载自己本地数据集 深度学习数据预处理

数据集存放

将数据集分类存放,不同类别放在不同目录下,目录名即标签,数据集存放格式如下:
root/ants/001.png
root/ants/002.png
root/ants/003.png

root/bees/001.png
root/bees/002.png
root/bees/003.png

图像数据预处理

1、使用transforms设置图像预处理操作

设置剪裁、缩放、翻转等参数,详细参考链接transforms

# 设置图像数据预处理操作
data_transform = transforms.Compose([
 # 随机缩放剪裁 size 224*224
  transforms.RandomResizedCrop(224),
 # 依概率p水平翻转
  transforms.RandomHorizontalFlip(p=0.5),
 # 转为张量并归一化到[0,1](是将数据除以255),且会把H*W*C会变成C *H *W
  transforms.ToTensor(),
 # 数据归一化处理
  transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
 ])
2、使用 datasets.ImageFolder对指定位置图像进行如上预处理操作

root处写入本地数据集地址,对该地址的图像进行上文data_transform设置的图像预处理操作

# 对目标位置图像进行预处理
train_dataset = datasets.ImageFolder(root='Dataset/train', transform=data_transform)

图像数据集加载

1、使用torch.utils.data.DataLoader加载数据集

加载预处理好的数据集train_dataset,每批数据量大小为64

# 数据加载器 加载数据集
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64)

源码如下:

import torch
from torchvision import transforms, datasets

# 设置图像数据预处理操作
data_transform = transforms.Compose([
 # 随机缩放剪裁 size 224*224
  transforms.RandomResizedCrop(224),
 # 依概率p水平翻转
  transforms.RandomHorizontalFlip(p=0.5),
 # 转为张量并归一化到[0,1](是将数据除以255),且会把H*W*C会变成C *H *W
  transforms.ToTensor(),
 # 数据归一化处理
  transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
 ])

# 对目标位置图像进行预处理
train_dataset = datasets.ImageFolder(root='Dataset/train', transform=data_transform)

# 查看数据集长度
train_data_size = len(train_dataset)
print("训练数据集长度为:{}".format(train_data_size))

# 数据加载器 加载数据集
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64)

# 查看图像数据预处理后,标准化、规范化的数据
for img, label in train_dataloader :
    print("图像img的形状{},标签label的值{}".format(img.shape, label))
    print("图像数据预处理后:\n",img)

参考文章:
transforms

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-07 13:43:59  更:2022-02-07 13:45:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 11:20:30-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码