| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 深度学习分类优化实战 -> 正文阅读 |
|
[人工智能]深度学习分类优化实战 |
文章目录近期做了一些与分类相关得实验,主要研究了模型有过过程中的一些优化手段,这里记录下,本文对相关模型和算法进行了实现并运行测试,整体来说,有的优化手段可以增加模型的准确率,有的可能没啥效果,总的记录如下文。本文使用得数据集为CIFAR-100 。 代码地址:传送门 一、优化策略1、CIFAR-100 数据集简介首先,我们需要拿到数据和明确我们的任务。这里以cifar-100为例,它是8000万个微小图像数据集的子集,他们由Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集。CIFAR -100数据集(100 个类别)是 Tiny Images 数据集的子集,由 60000 个 32x32 彩色图像组成。CIFAR-100 中的 100 个类分为 20 个超类。每个类有 600 张图像。每个图像都带有一个“精细”标签(它所属的类)和一个“粗略”标签(它所属的超类)。每个类有 500 个训练图像和 100 个测试图像。 简单来说,我们需要针对CIFAR-100 数据集,设计、搭建、训练机器学习模型,能够尽可能准确地分辨出测试数据地标签。 参考连接: 2、模型评估指标对于分类模型,最主要的是看模型的准确率。当然,光从准确率不能完全评估模型的性能,我还需要从混淆矩阵来看每一类的分类情况,PR曲线分析我们模型的准确率和召回率,ROC曲线评估模型的泛化能力。具体实现可以参考本文代码
通过观察,可以看出模型对每一类都能很好的进行分类。
3、数据!数据!数据!3.1、数据增强数据增强是解决过拟合一个比较好的手段,它的本质是在一定程度上扩充训练数据样本,避免模型拟合到训练集中的噪声,所以设计一个好的数据增强方案尤为必要。在CV任务中,常用的数据增强包括RandomCrop(随机扣取)、Padding(补丁)、RandomHorizontalFlip(随机水平翻转)、ColorJilter(颜色抖动)等。还有一些其他高级的数据增强技巧,比如RandomEreasing(随机擦除)、MixUp、CutMix、AutoAugment,以及最新的AugMix和GridMask等。在实际训练中,如何选择,需要以具体实验为主,主要需要参考一些优秀论文,借鉴何使用。在此次任务中我们除了一些常用的增强方法外,也选择了一些加分点的优化手段,然后通过选择实验对比,选择较合适的数据增强方案。具体实现 主要对比如下:
3.2、数据分布本文使用的CIFAR-100数据集的每一个类属于数据比较均衡的,但在实际分类中,大多数是不均衡的长尾数据,这个时候需要减少这种不均衡对预测的影响。当然,除了长尾分布的影响,还有类间相似的影响,比如两个类比较接近,无论形状、大小或颜色等,需要算法进一步区分或尽量减少对分类的影响。常用的解决长尾分布手手段有:重采样(需要在不影响原始分布的情况,如异常检测,这种情况重采样会改变数据原始分布,反而会降低准确率,因为本来就是正/负样本多)、重新设计loss(如Focal loss、OHEM、Class Balanced Loss)、或者转化为异常检测以及One-class分类模型等。 对于多类别问题,同一张图片可能有多个类,此时传统的CE loss的设计就有一定缺陷了。因为在多标签分类中,一个数据点中可以有多个正确的类。因此,多标签分类问题的需要检测图像中存在的每个对象。而CE loss会尽可能拟合one-hot标签,容易造成过拟合,无法保证模型的泛化能力,同时由于无法保证标签百分百正确,可能存在一些错误标签,但模型也会拟合这些错误标签,由于以上原因,提出了标签平滑,为软标签,属于正则化的一种,可以防止过拟合。label smoothing标签平滑实现见 参考链接: 4、模型选择模型的选择优先考虑最新最好的模型,可以参考传送门,选择合适的模型。这里,我选择的ResNet模型作为baseline backbone。 这里我们进行不同的模型比较,实验如下:
可以看出模型越复杂,能提升我们的模型准确率。所以后续我们也选择了wideresnet这样的大的模型来训练这个对模型的准确率也有很大的提升。当然,后续还可以选择当前最新的transformer模型,如:VIT、Swin、CaiT等,作为我们的训练模型。 参考链接: 一文窥探近期大火的Transformer以及在图像分类领域的应用_果菌药的博客-程序员ITS401_transformer图像分类 Transformer小试牛刀(一):Vision Transformer 5、模型优化5.1、学习率选择我们通过枚举不同学习率下的loss值选择最优学习率(具体实现 通过观察可知,lr=0.1时loss最低,此时学习率最优。 5.2、优化器选择对于深度学习来说,优化器比较多,如:SGD、Adagrad、Adadelta、RMSprop、Adam等。当然,也有最新的优化器,如:Ranger、SAM等(具体实现 这里我们对不同的优化器比较,实验如下:
通过观察可知,选择SAM优化器最优。 参考链接: 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam) 5.3、学习率更新策略选择这里我们选择warmup预热更新策略,具体实现 5.4、loss选择在前面的数据分析中,我们讨论了数据分布的问题,由于我们的数据是多分类问题,所以我们需要在交叉熵损失函数的基础上加入标签平滑,这样能够更好的训练,防止过拟合。 这里我们对不同的损失函数比较,实验如下:
6、整体思路
我们初步训练resnet50作为基础模型,实验测试过程如下:
通过实验,我们最终选择wideresnet40_10作为特征提取模型,实验过程中将Accuracy由78%提升到84.37%。 二、pytorch实战
模型参考链接:
|
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/1 22:29:22- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |