IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch cifar100数据集的简单理解与用法 -> 正文阅读

[人工智能]Pytorch cifar100数据集的简单理解与用法

https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR100.html#torchvision.datasets.CIFAR100

torchvision.datasets中提供了一些经典数据集,其中最为常用的是cifar10/100,mnist,在搓增量学习、领域自适应、主动学习等任务时经常需要打交道。这里我们以cifar100为例看一下其基本的用法。

首先,下载训练集与测试集:

from torchvision import datasets

train_dataset = datasets.cifar.CIFAR100(root='cifar100', train=True, transform=None, download=True)
test_dataset = datasets.cifar.CIFAR100(root='cifar100', train=False, transform=None, download=True)

可以看到有四个参数:

  • root:数据集文件的存储路径。
  • train:是否为训练集。True则表示视为训练集,False表示视为测试集。
  • transform:所应用的数据扩充方法。
  • download:是否下载。如果为True且root路径下无相应的数据集文件,则自动从互联网上下载数据集至给定路径。

到这里,严格来讲就算介绍完毕了,因为这里得到的train_datasettest_dataset都属于torch.utils.data.Dataset对象,剩下来的用法和我们自己手工封装数据集的一致。现在,我们着重考察下cifar100数据集的结构。

首先,直接用下标去访问一个数据集对象:

print(train_dataset[0])

输出结果如下:

(<PIL.Image.Image image mode=RGB size=32x32 at 0x25568409670>, 19)

可以看到得到的是一个tuple,第一项为image,第二项为ground truth,即图像所属的分类,使用一个int值表示。在训练计算损失函数的时候,直接使用F.cross_entropy即可,而不需要考虑将int标签转化成独热向量的形式。而对于image,在实际训练中读数据集时transform一项必然带有transforms.ToTensor(),在这种情况下返回的则是一个Tensor向量以便网络训练。

此外,还有另一种访问方法:

print(train_dataset.data[0])
print(train_dataset.targets[0])

输出结果如下:

[[[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
              ]]]
19

此时,标签仍为int不变,而数据返回的形式为ndarray:

print(train_dataset.data[0].shape)

输出结果如下:

(32, 32, 3)

可以看到cifar100图像的尺寸为32×32×3(3表示通道数)。这里我们也顺便将这张图像展示:

train_dataset[0][0].show()

在这里插入图片描述
因为只有1024像素所以非常小。查阅类别表可以发现类-19对应cattle(牛),这也与我们的观察相吻合。

现在我们来看另一个问题,可以发现第一张图像的类为19,那么这就表明,整个数据集50000张训练图像并不是按类别顺序进行划分的,即甚至可以在使用dataloader时不打开shuffle。为了验证这一点,我们直接输出train_dataset.targets中的前10个标签:

print(train_dataset.targets[:10])

结果如下:

[19, 29, 0, 11, 1, 86, 90, 28, 23, 31]

可以发现确实是乱序的。但是在一些奇怪的任务里面,会要求类别是有序的(实际上我们自己做的数据集也是尽量有序的,需要打乱通过dataloader实现即可),那么这里就看一下怎么去弄。具体来说,肯定是从targets中所包含的类别信息入手。首先将其从list转为ndarray方便我们使用numpy去操作:

train_targets = np.array(train_dataset.targets)

那么,比方说我们要取出类别在[10, 19]内的全部样本,就可以把相应的target先取出来。这里先得到一个长度为50000的包含每个元素是否满足条件的列表:

idx = np.logical_and(train_targets >= 10, train_targets < 20)
print(len(idx))
print(idx)

输出:

50000
[ True False False ... False False False]

然后使用np.where()进行广播,获得具体的下标:

idx = np.where(idx)
print(len(idx[0]))
print(idx[0])

可以看到满足要求的target所对应的下标共有5000个,确实是十分之一:

5000
[    0     3    13 ... 49977 49981 49991]

最后将这些下标作为索引取出相应的子数据集即可:

train_data = np.array(train_dataset.data)
selected_data = train_data[idx[0]]
selected_target = train_targets[idx[0]]
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-09 18:22:35  更:2022-04-09 18:24:57 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:47:58-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码