IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习分类优化实战 -> 正文阅读

[人工智能]深度学习分类优化实战


近期做了一些与分类相关得实验,主要研究了模型有过过程中的一些优化手段,这里记录下,本文对相关模型和算法进行了实现并运行测试,整体来说,有的优化手段可以增加模型的准确率,有的可能没啥效果,总的记录如下文。本文使用得数据集为CIFAR-100 。

代码地址:传送门

一、优化策略

1、CIFAR-100 数据集简介

首先,我们需要拿到数据和明确我们的任务。这里以cifar-100为例,它是8000万个微小图像数据集的子集,他们由Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集。CIFAR -100数据集(100 个类别)是 Tiny Images 数据集的子集,由 60000 个 32x32 彩色图像组成。CIFAR-100 中的 100 个类分为 20 个超类。每个类有 600 张图像。每个图像都带有一个“精细”标签(它所属的类)和一个“粗略”标签(它所属的超类)。每个类有 500 个训练图像和 100 个测试图像。

简单来说,我们需要针对CIFAR-100 数据集,设计、搭建、训练机器学习模型,能够尽可能准确地分辨出测试数据地标签。

参考连接:

CIFAR100数据集介绍及使用方法

2、模型评估指标

对于分类模型,最主要的是看模型的准确率。当然,光从准确率不能完全评估模型的性能,我还需要从混淆矩阵来看每一类的分类情况,PR曲线分析我们模型的准确率和召回率,ROC曲线评估模型的泛化能力。具体实现可以参考本文代码utils/metric.py

  • 混淆矩阵

在这里插入图片描述

通过观察,可以看出模型对每一类都能很好的进行分类。

  • PR曲线

在这里插入图片描述

  • ROC曲线

在这里插入图片描述

3、数据!数据!数据!

3.1、数据增强

数据增强是解决过拟合一个比较好的手段,它的本质是在一定程度上扩充训练数据样本,避免模型拟合到训练集中的噪声,所以设计一个好的数据增强方案尤为必要。在CV任务中,常用的数据增强包括RandomCrop(随机扣取)、Padding(补丁)、RandomHorizontalFlip(随机水平翻转)、ColorJilter(颜色抖动)等。还有一些其他高级的数据增强技巧,比如RandomEreasing(随机擦除)、MixUp、CutMix、AutoAugment,以及最新的AugMix和GridMask等。在实际训练中,如何选择,需要以具体实验为主,主要需要参考一些优秀论文,借鉴何使用。在此次任务中我们除了一些常用的增强方法外,也选择了一些加分点的优化手段,然后通过选择实验对比,选择较合适的数据增强方案。具体实现utils/augment/augment.py

主要对比如下:

methodacc
RandomCrop+RandomHorizontalFlip+RandomRotation0.78
RandomCrop+RandomHorizontalFlip+RandomRotation+random_erase0.79
RandomCrop+RandomHorizontalFlip+RandomRotation+random_erase+autoaugment0.81

3.2、数据分布

本文使用的CIFAR-100数据集的每一个类属于数据比较均衡的,但在实际分类中,大多数是不均衡的长尾数据,这个时候需要减少这种不均衡对预测的影响。当然,除了长尾分布的影响,还有类间相似的影响,比如两个类比较接近,无论形状、大小或颜色等,需要算法进一步区分或尽量减少对分类的影响。常用的解决长尾分布手手段有:重采样(需要在不影响原始分布的情况,如异常检测,这种情况重采样会改变数据原始分布,反而会降低准确率,因为本来就是正/负样本多)、重新设计loss(如Focal loss、OHEM、Class Balanced Loss)、或者转化为异常检测以及One-class分类模型等。

对于多类别问题,同一张图片可能有多个类,此时传统的CE loss的设计就有一定缺陷了。因为在多标签分类中,一个数据点中可以有多个正确的类。因此,多标签分类问题的需要检测图像中存在的每个对象。而CE loss会尽可能拟合one-hot标签,容易造成过拟合,无法保证模型的泛化能力,同时由于无法保证标签百分百正确,可能存在一些错误标签,但模型也会拟合这些错误标签,由于以上原因,提出了标签平滑,为软标签,属于正则化的一种,可以防止过拟合。label smoothing标签平滑实现见utils/losses.py

参考链接:

样本不均衡、长尾分布问题的方法整理(文献+代码)

视觉分类任务中处理不平衡问题的loss比较

长尾分布分类问题解决方法

4、模型选择

模型的选择优先考虑最新最好的模型,可以参考传送门,选择合适的模型。这里,我选择的ResNet模型作为baseline backbone。

在这里插入图片描述

这里我们进行不同的模型比较,实验如下:

methodacc
resnet180.75
resnet500.78
resnet1010.79

可以看出模型越复杂,能提升我们的模型准确率。所以后续我们也选择了wideresnet这样的大的模型来训练这个对模型的准确率也有很大的提升。当然,后续还可以选择当前最新的transformer模型,如:VIT、Swin、CaiT等,作为我们的训练模型。

参考链接:

一文窥探近期大火的Transformer以及在图像分类领域的应用_果菌药的博客-程序员ITS401_transformer图像分类

Transformer小试牛刀(一):Vision Transformer

5、模型优化

5.1、学习率选择

我们通过枚举不同学习率下的loss值选择最优学习率(具体实现tool/lr_finder.py),绘制曲线如下:

在这里插入图片描述

通过观察可知,lr=0.1时loss最低,此时学习率最优。

5.2、优化器选择

对于深度学习来说,优化器比较多,如:SGD、Adagrad、Adadelta、RMSprop、Adam等。当然,也有最新的优化器,如:Ranger、SAM等(具体实现utils/optim.py)。

这里我们对不同的优化器比较,实验如下:

methodacc
SGD0.79
adam0.79
ranger0.65
SAM0.8311

通过观察可知,选择SAM优化器最优。

参考链接:

深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

再也不用担心过拟合的问题了

5.3、学习率更新策略选择

这里我们选择warmup预热更新策略,具体实现utils/scheduler.py
在这里插入图片描述

5.4、loss选择

在前面的数据分析中,我们讨论了数据分布的问题,由于我们的数据是多分类问题,所以我们需要在交叉熵损失函数的基础上加入标签平滑,这样能够更好的训练,防止过拟合。

这里我们对不同的损失函数比较,实验如下:

methodacc
CE0.8311
smooth_CE0.833

6、整体思路

  • lr:
    • warmup (5 epoch)
    • cosine lr decay
    • lr=0.1
    • total epoch(200 epoch)
  • bs=128
  • aug:
    • Random Crop and resize
    • Random left-right flipping
    • Random rotation
    • AutoAugment
    • Normalization
    • Random Erasing
  • weight decay=5e-4 (bias and bn undecayed)
  • kaiming weight init
  • optimizer: SAM
  • loss: smooth_CE
  • TTA

我们初步训练resnet50作为基础模型,实验测试过程如下:

networkmethodacc
resnet18SGD+warmup+CE0.75
resnet50SGD+warmup+CE0.78
resnet101SGD+warmup+CE0.79
resnet50SGD+warmup+random_erase+CE0.79
resnet50SGD+warmup+random_erase+autoaugment+CE0.815
resnet50adam+warmup+random_erase+autoaugment+CE0.79
resnet50ranger+warmup+random_erase+autoaugment+CE0.65
resnet50SAM+warmup+random_erase+autoaugment+CE0.8311
resnet50SAM+warmup+random_erase+autoaugment+smooth_CE0.833
wideresnet40_10SAM+warmup+random_erase+autoaugment+smooth_CE0.840
wideresnet40_10SAM+warmup+random_erase+autoaugment+smooth_CE+TTA0.8437

通过实验,我们最终选择wideresnet40_10作为特征提取模型,实验过程中将Accuracy由78%提升到84.37%。

二、pytorch实战

  1. 安装要求

    • python3.6
    • pytorch1.6.0+cu101
    • tensorboard 2.2.2(optional)
  2. 运行tensorboard

$ cd pytorch-cifar100
$ mkdir runs
$ tensorboard --logdir='runs' --port=6006 --host='localhost'
  1. 训练模型
$ python train.py -gpu
  1. 测试模型
$ python test.py 

模型参考链接:

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-25 11:36:15  更:2022-05-25 11:36:27 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 22:29:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码