IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> CIFAR-10代码解读 -> 正文阅读

[人工智能]CIFAR-10代码解读

文章目录

1. 代码

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: CF10_test
# @Create time: 2022/1/8 17:08

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: CF10_test
# @Create time: 2022/1/8 17:08


# 1. 导入相关数据库
import collections
import math
import os
import shutil
import torch
import torchvision
from torch import nn
from torch import utils
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 2.下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
								'2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

demo = True

if demo:
	data_dir = d2l.download_extract('cifar10_tiny')
else:
	data_dir = '../data/cifar-10/'


# 3. 读取csv文件,按行读取csv文件,返回 (name,label)数据对
def read_csv_labels(fname):
	"""

	:param fname: 需要读取的csv文件的地址
	:return: 返回字典对 ,dict {name,label}
	"""
	# 读取csv文件
	with open(fname, 'r') as f:
		# 将整行读进来,因为第一行为标题,所以从第二行进行读取
		lines = f.readlines()[1:]
	# 因为每行有一个 \n ,所以要先删除,后根据 ","进行分割得到 list
	tokens = [l.rstrip().split(',') for l in lines]
	# 逐个迭代token ,生成字典对
	return dict(((name, label) for name, label, in tokens))


# 4.将字典对返回给label,做好标签映射
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('# 训练样本:', len(labels))
print('# 类别 :', len(set(labels.values())))


# 5. 将源filename里面的文件拷贝到目标文件夹中
def copyfile(filename, target_dir):
	# 创建一个文件夹,exist_ok=True:表示文件夹存在的情况下不触发异常
	os.makedirs(target_dir, exist_ok=True)
	# 将源filename文件夹拷贝到目标文件夹target_dir中
	shutil.copy(filename, target_dir)


# 6. 将验证集从原始的训练集中拆分出来
def reorg_train_valid(data_dir, labels, valid_ratio):
	"""
	function : 将验证集从原始的训练集中拆分出来
	:param data_dir:  数据目录
	:param labels: 所有的标签
	:param valid_ratio: 验证集分割的比例
	:return:
	"""
	# 训练数据集中样本最少的类别中的样本数
	# n 为标签中频率最低的标签对应的频数
	n = collections.Counter(labels.values()).most_common()[-1][1]
	# 验证集中每个类别的样本数

	# 按照比例将部分数据作为验证集,n_valid_per_label =8
	n_valid_per_label = max(1, math.floor(n * valid_ratio))
	# 查看label的种类个数
	label_count = {}
	# 逐个遍历train文件夹里面图片的文件名
	for train_file in os.listdir(os.path.join(data_dir, 'train')):
		# 文件名的序号通过labels字典对的映射得到图片对应的实际标签,比如 1 在 csv 中表示 frog
		# label 表示图片的实际类别,如 frog
		label = labels[train_file.split('.')[0]]
		fname = os.path.join(data_dir, 'train', train_file)
		# 通过label名来创建文件夹,将对应的图片放到相应的文件夹中
		copyfile(fname, os.path.join(data_dir, 'train_valid_test',
									 'train_valid', label))
		if label not in label_count or label_count[label] < n_valid_per_label:
			copyfile(fname, os.path.join(data_dir, 'train_valid_test',
										 'valid', label))
			label_count[label] = label_count.get(label, 0) + 1
		else:
			copyfile(fname, os.path.join(data_dir, 'train_valid_test',
										 'train', label))
	return n_valid_per_label


# 7. 整理文件夹
def reorg_test(data_dir):
	"""
	在预测期间整理测试集,以方便读取
	:param data_dir:
	:return:
	"""
	# 遍历 test 文件夹,将不同的文件进行拷贝到另一个文件夹中
	#  data_dir/test  ->copy-> data_dir/train_valid_test/test/unkown
	for test_file in os.listdir(os.path.join(data_dir, 'test')):
		copyfile(os.path.join(data_dir, 'test', test_file),
				 os.path.join(data_dir, 'train_valid_test', 'test',
							  'unknown'))


# 8. 读取reorg_cifa10_data 文件
def reorg_cifa10_data(data_dir, valid_ratio):
	labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
	reorg_train_valid(data_dir, labels, valid_ratio)
	reorg_test(data_dir)


batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifa10_data(data_dir, valid_ratio)

# 9.训练集train预处理
transform_train = torchvision.transforms.Compose([
	# 将大小变成40像素的正方形
	torchvision.transforms.Resize(40),
	# 随机剪裁
	torchvision.transforms.RandomSizedCrop(32, scale=(0.64, 1.0),
										   ratio=(1.0, 1.0)),
	# 随机水平翻转
	torchvision.transforms.RandomHorizontalFlip(),
	# 图片转成张量
	torchvision.transforms.ToTensor(),
	# 正则化每个通道的值
	torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
									 [0.2023, 0.1994, 0.2010])])

# 10.测试集预处理
transform_test = torchvision.transforms.Compose([
	torchvision.transforms.ToTensor(),
	torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
									 [0.2023, 0.1994, 0.2010])])



# 11.按文件夹形式读取训练集
# image+label -> dataset -> dataloader
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
	os.path.join(data_dir, 'train_valid_test', folder),
	transform=transform_train) for folder in ['train', 'train_valid']]

# 12.按文件夹形式读取测试集生成 dataset
valid_ds, test_ds = [torchvision.datasets.ImageFolder(
	os.path.join(data_dir, 'train_valid_test', folder),
	transform=transform_test) for folder in ['valid', 'test']]

# 13.从dataset转换成 dataloader
# 训练迭代器,验证迭代器,shuffle=打乱数据提高鲁棒,
train_iter, train_valid_iter = [torch.utils.data.DataLoader(
	dataset, batch_size, shuffle=True, drop_last=True)
	for dataset in (train_ds, train_valid_ds)]

valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
										 drop_last=True)

test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
										drop_last=False)


# 14.定义神经网络模型
def get_net():
	num_classes = 10
	net = d2l.resnet18(num_classes, 3)
	return net


# 15.定义损失函数
loss = nn.CrossEntropyLoss(reduction="none")


# 16.定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
		  lr_decay):
	"""
	function : 定义训练函数
	:param net: 神经网络
	:param train_iter: 训练迭代器
	:param valid_iter: 验证迭代器
	:param num_epochs: 迭代的次数
	:param lr: 学习率
	:param wd: 权重衰减
	:param devices: GPU
	:param lr_period:
	:param lr_decay:
	:return:
	"""
	# 定义优化器 SGD 随机梯度下降
	trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
							  weight_decay=wd)
	# 设置优化器trainer中学习率改变的方式,
	scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
	# 设置批量大小和计时器
	num_batches, timer = len(train_iter), d2l.Timer()
	# 设置画布上的显示值
	legend = ['train loss', 'train acc']
	# 如果有验证集,就增加到显示画布上
	if valid_iter is not None:
		legend.append('valid acc')
	# 画动图,x轴是迭代次数,
	animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
							legend=legend)
	# 实现模型上的数据并行
	net = nn.DataParallel(net, device_ids=devices).to(devices[0])
	for epoch in range(num_epochs):
		# 启动神经网络的训练模式
		net.train()
		# 定义累加器
		metric = d2l.Accumulator(3)
		# 从训练迭代器中成对的拿到 (features,labels)
		for i, (features, labels) in enumerate(train_iter):
			# 计时开始
			timer.start()
			# 根据训练集中的features,labels得到损失l和精度acc
			# 在GPU上进行小批量训练数据
			l, acc = d2l.train_batch_ch13(net, features, labels,
										  loss, trainer, devices)
			# 累加器增加数据
			metric.add(l, acc, labels.shape[0])
			# 计时结束
			timer.stop()
			if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
				animator.add(epoch + (i + 1) / num_batches,
							 (metric[0] / metric[2], metric[1] / metric[2], None))
		if valid_iter is not None:
			valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
			animator.add(epoch + 1, (None, None, valid_acc))
		# 调度器更新
		scheduler.step()
	measures = (f'train loss{metric[0] / metric[2]:.3f},'
				f'train acc {metric[1] / metric[2]:.3f}')
	if valid_iter is not None:
		measures += f',valid acc{valid_acc:3f}'
	print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
					 f'examples/sec on {str(devices)}')


# 17. 设置基本参数
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4

# 18.设置神经网络相关参数
lr_period, lr_decay, net = 4, 0.9, get_net()

# 19.启动训练
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 20.显示结果
plt.show()

2.结果

train loss0.668,train acc 0.767,valid acc0.375000
581.3examples/sec on [device(type=‘cuda’, index=0)]
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-01-11 23:59:55  更:2022-01-12 00:01:12 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 17:36:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码