IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Resnet实现CIFAR-10图像分类 —— Mindspore实践 -> 正文阅读

[人工智能]Resnet实现CIFAR-10图像分类 —— Mindspore实践

????????计算机视觉是当前深度学习研究最广泛、落地最成熟的技术领域,在手机拍照、智能安防、自动驾驶等场景有广泛应用。从2012年AlexNet在ImageNet比赛夺冠以来,深度学习深刻推动了计算机视觉领域的发展,当前最先进的计算机视觉算法几乎都是深度学习相关的。深度神经网络可以逐层提取图像特征,并保持局部不变性,被广泛应用于分类、检测、分割、跟踪、检索、识别、提升、重建等视觉任务中。 结合图像分类任务,了解MindSpore如何应用于计算机视觉场景,如何训练模型,得出一个性能较优的模型。

????????CIFAR-10 是一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。?
????????下面这幅图列举了10各类,每一类展示了随机的10张图片:


与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:
? CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
? CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
? 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。 直接的线性模型如 Softmax 在 CIFAR-10 上表现得很差。

????????图像分类是最基础的计算机视觉应用,属于有监督学习类别。给定一张数字图像,判断图像所属的类别,如猫、狗、飞机、汽车等等。用函数来表示这个过程如下:

????????定义的分类函数,以图片数据image为输入,通过model方法对image进行分类,最后返回分类结果。选择合适的model是关键。这里的model一般指的是深度卷积神经网络,如AlexNet、VGG、GoogLeNet、ResNet等等。
????????下面按照MindSpore的训练数据模型的正常步骤进行,当使用到MindSpore或者图像分类操作时,会增加相应的说明,整体流程如下:

  1. 数据集的准备,这里使用的是CIFAR-10数据集。

  2. 构建一个卷积神经网络,这里使用ResNet-50网络。

  3. 定义损失函数和优化器。

  4. 调用Model高阶API进行训练和保存模型文件。

  5. 进行模型精度验证。

训练数据集下载?

import mindspore
print(mindspore.__version__)

数据集准备?

!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10.zip
!unzip -o cifar10.zip -d ./datasets
!tree ./datasets/cifar10

????????数据集处理对于训练非常重要,好的数据集可以有效提高训练精度和效率。在加载数据集前,通常会对数据集进行一些处理。这里用到了数据增强,数据混洗和批处理。

????????数据增强主要是对数据进行归一化和丰富数据样本数量。常见的数据增强方式包括裁剪、翻转、色彩变化等等。MindSpore通过调用map方法在图片上执行增强操作。数据混洗和批处理主要是通过数据混洗shuffle随机打乱数据的顺序,并按batch读取数据,进行模型训练。

????????构建create_dataset函数,来创建数据集。通过设置 resize_heightresize_widthrescaleshift参数,定义map以及在图片上运用map实现数据增强。

import mindspore.nn as nn
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore import context
import numpy as np
import matplotlib.pyplot as plt

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

def create_dataset(data_home, repeat_num=1, batch_size=32, do_train=True, device_target="GPU"):
    """
    create data for next use such as training or inferring
    """

    cifar_ds = ds.Cifar10Dataset(data_home,num_parallel_workers=8, shuffle=True)

    c_trans = []
    if do_train:
        c_trans += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5)
        ]

    c_trans += [
        C.Resize((224, 224)),
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=8)

    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    cifar_ds = cifar_ds.repeat(repeat_num)

    return cifar_ds


ds_train_path = "./datasets/cifar10/train/"
dataset_show = create_dataset(ds_train_path)
with open(ds_train_path+"batches.meta.txt","r",encoding="utf-8") as f:
    all_name = [name.replace("\n","") for name in f.readlines()]

iterator_show= dataset_show.create_dict_iterator()
dict_data = next(iterator_show)
images = dict_data["image"].asnumpy()
labels = dict_data["label"].asnumpy()
count = 1
%matplotlib inline
for i in images:
    plt.subplot(4, 8, count)
    # Images[0].shape is (3,224,224).We need transpose as (224,224,3) for using in plt.show().
    picture_show = np.transpose(i,(1,2,0))
    picture_show = picture_show/np.amax(picture_show)
    picture_show = np.clip(picture_show, 0, 1)
    plt.title(all_name[labels[count-1]])
    picture_show = np.array(picture_show,np.float32)
    plt.imshow(picture_show)
    count += 1
    plt.axis("off")

print("The dataset size is:", dataset_show.get_dataset_size())
print("The batch tensor is:",images.shape)
plt.show()

????????数据集生成后,选取一个batch的图像进行可视化查看,经过数据增强后,原数据集变成了每个batch张量为,共计1572个batch的新数据集。?

定义卷积神经网络

????????卷积神经网络已经是图像分类任务的标准算法了。卷积神经网络采用分层的结构对图片进行特征提取,由一系列的网络层堆叠而成,比如卷积层、池化层、激活层等等。 ResNet-50通常是较好的选择。首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。此外,ResNet-50网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。

下载构建好的resnet50网络源码文件。

!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/source-codes/resnet.py

下载下来的resnet.py在当前目录,可以使用import方法将resnet50网络导出。?

from resnet import resnet50

net = resnet50(batch_size=32, num_classes=10)

定义损失函数和优化器

????????接下来需要定义损失函数(Loss)和优化器(Optimizer)。损失函数是深度学习的训练目标,也叫目标函数,可以理解为神经网络的输出(Logits)和标签(Labels)之间的距离,是一个标量数据。 常见的损失函数包括均方误差、L2损失、Hinge损失、交叉熵等等。图像分类应用通常采用交叉熵损失(CrossEntropy)。 优化器用于神经网络求解(训练)。由于神经网络参数规模庞大,无法直接求解,因而深度学习中采用随机梯度下降算法(SGD)及其改进算法进行求解。MindSpore封装了常见的优化器,如SGD、ADAM、Momemtum等等。本例采用Momentum优化器,通常需要设定两个参数,动量(moment)和权重衰减项(weight decay)。

????????通过调用MindSpore中的API:MomentumSoftmaxCrossEntropyWithLogits,设置损失函数和优化器的参数。

import mindspore.nn as nn
from mindspore.nn import SoftmaxCrossEntropyWithLogits

ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

调用Model高阶API进行训练和保存模型文件

????????完成数据预处理、网络定义、损失函数和优化器定义之后,就可以进行模型训练了。模型训练包含两层迭代,数据集的多轮迭代(epoch)和一轮数据集内按分组(batch)大小进行的单步迭代。其中,单步迭代指的是按分组从数据集中抽取数据,输入到网络中计算得到损失函数,然后通过优化器计算和更新训练参数的梯度。

????????为了简化训练过程,MindSpore封装了Model高阶接口。用户输入网络、损失函数和优化器完成Model的初始化,然后调用train接口进行训练,train接口参数包括迭代次数epoch和数据集dataset

????????模型保存是对训练参数进行持久化的过程。Model类中通过回调函数的方式进行模型保存,如下面代码所示。用户通过CheckpointConfig设置回调函数的参数,其中,save_checkpoint_steps指每经过固定的单步迭代次数保存一次模型,keep_checkpoint_max指最多保存的模型个数。

????????本次选择epoch_size为10,一共迭代了10次,大约耗时25分钟,得到如下的运行结果。可以自行设置不同的epoch_size,生成不同的模型,在下面的验证部分查看模型精确度。

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
import os
from mindspore import Model


model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# As for train, users could use model.train

epoch_size = 10
ds_train_path = "./datasets/cifar10/train/"
model_path = "./models/ckpt/mindspore_vision_application/"
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))

dataset = create_dataset(ds_train_path )
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory=model_path, config=config_ck)
loss_cb = LossMonitor(142)
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
epoch: 1 step: 1562, loss is 1.2250829
epoch: 2 step: 1562, loss is 0.948782
epoch: 3 step: 1562, loss is 1.02575
epoch: 4 step: 1562, loss is 0.8370316
epoch: 5 step: 1562, loss is 0.65224147
epoch: 6 step: 1562, loss is 0.5031056
epoch: 7 step: 1562, loss is 0.39631012
epoch: 8 step: 1562, loss is 0.21934134
epoch: 9 step: 1562, loss is 0.35878238
epoch: 10 step: 1562, loss is 0.34452274?

查询训练过程中,保存好的模型。

!tree ./models/ckpt/mindspore_vision_application/

每1562个step保存一次模型权重参数.ckpt文件,一共保存了10个,另外.meta文件保存模型的计算图信息。

进行模型精度验证

调用model.eval得到最终精度超过0.80,准确度较高,验证得出模型是性能较优的。

# As for evaluation, users could use model.eval
ds_eval_path = "./datasets/cifar10/test/"
eval_dataset = create_dataset(ds_eval_path, do_train=False)
res = model.eval(eval_dataset)
print("result: ", res)
result:  {'acc': 0.8165064102564102}?
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-08-19 19:04:57  更:2022-08-19 19:05:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 18:34:08-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计