IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 李宏毅机器学习课程梳理【十三】:想到新优化算法的思路 -> 正文阅读

[人工智能]李宏毅机器学习课程梳理【十三】:想到新优化算法的思路

摘要

SGDM与Adam算法效果较好,很长一段时间以来未被明显地超越,Adam的强大不仅源于它自适应调节 η \eta η时用均方根和参数计算,可调整当前梯度与过去累加的梯度计算时占比,从而更精准地更新参数;还源于Momentum不易在梯度碰巧为0的局部最低点和梯度接近于0的平坦表面卡住。通过分析它们在训练集、验证集和测试集上的准确性以及Perplexity(混乱度),观察各自特点,思考改进算法的方向。

1 如何想到Adam?

对于SGD,从Learning rate的调节入手,优化出Adagrad算法: θ t = θ t ? 1 ? η ∑ i = 0 t ? 1 ( g i ) 2 g t ? 1 \theta_t=\theta_{t-1}-\dfrac{\eta}{\sqrt{\textstyle\sum_{i=0}^{t-1}(g_i)^2}}g_{t-1} θt?=θt?1??i=0t?1?(gi?)2 ?η?gt?1?,此算法可以防止在梯度过大的某一段,步伐过大,导致移动到更差的位置。却又出现分母单调递增,如果梯度大,没走几步就会卡住的问题。后来,优化出RMSProp算法: θ t = θ t ? 1 ? η σ t g t \theta_t=\theta_{t-1}-\dfrac{\eta}{\sigma^t}g^t θt?=θt?1??σtη?gt, σ t = α ( σ t ? 1 ) 2 + ( 1 ? α ) ( g t ) 2 \sigma^t=\sqrt{\alpha(\sigma^{t-1})^2+(1-\alpha)(g^t)^2} σt=α(σt?1)2+(1?α)(gt)2 ?,此算法通过调节 α \alpha α来控制过去的梯度与当前的梯度对Learning rate影响的程度,使Adaptive Learning Rate更精准。

然而Adaptive Learning Rate只能改变梯度大小,不能改变方向,移动方向全部由梯度决定,这样依然容易在plateau处更新缓慢、卡在saddle point和local minima的地方。在更新参数时添加一个类似惯性的量,优化出Momentum算法:定义vector m t + 1 m_{t+1} mt+1?为第0步到第t+1步梯度的累加。在计算movement时, m t = λ m t ? 1 ? η ? L ( θ t ? 1 ) m_t=\lambda m_{t-1}-\eta \nabla L(\theta^{t-1}) mt?=λmt?1??η?L(θt?1),可见移动方向可以受之前movement的影响,有可能会冲过local minima而找到global minima。

将两者结合,既可以adaptive learning rate加速收敛,又可以提高收敛于global minima的可能性。

2 如何训练出新算法?

第一步,分析效果最佳的Adam算法和SGDM算法,分析它们在训练集、验证集和测试集上的准确性以及Perplexity(混乱度),如图1、图2、图3和图4所示。
1
在训练集上Adam较快冲到第一名,SGDM位于第三名但是波动明显小。
2

在验证集上,SGDM准确率最高且稳定。
3
在测试集上,收敛后的SGDM准确率高且稳定。
4
Perplexity表示混乱度,可以看出Adam训练速度快,但是后期准确率不如其他算法。

  • Adam : fast training, large generalization gap, unstable
  • SGDM : stable, little generalization gap, better convergence
  • generalization指在训练集与测试集上结果的差距

第二步,寻找Adam和SGDM准确率高以及形成它们各自特点的原因。
有观点认为,可能Adam和SGDM抢到了两个极端的位置。图5给出了两个算法的generalization gap有很大差距的直观解释。
5
图5的横坐标表示参数,纵坐标表示Loss function的值,由于Training Set与Testing Set的差异,两条曲线形状相近。取到Flat Minimum的算法会有较小generalization gap,取到Sharp Minimum的算法会有较大的generalization gap。

第三步,做Trouble shooting,对Adam改进,让Adam收敛得又稳又好,如消除数值小但数量多的梯度覆盖了数值大的梯度问题的AMSGrad算法(目前一些改进算法或只顾及learning rate过大的情况,或只顾及learning rate过小的情况,或改进后分母单增);对SGDM改进,在算法中增加改变 η \eta η的设计。

3 正则化Weight Decay&L2 Regularization

两者的区别仅是正则化的那一项影不影响 v ^ t , m ^ t \sqrt{\hat v_t}, \hat m_t v^t? ?,m^t?
Weight Decay算法正则化那一项不影响 v ^ t , m ^ t \sqrt{\hat v_t}, \hat m_t v^t? ?,m^t?,具体算法是:
SGDWM

  • θ t = θ t ? 1 ? m t ? r ? θ t ? 1 \theta_t=\theta_{t-1}-m_t-r\cdot\theta_{t-1} θt?=θt?1??mt??r?θt?1?
  • m t = λ ? m t ? 1 + η ( ? L ( θ t ? 1 ) ) m_t=\lambda\cdot m_{t-1}+\eta\big(\nabla L(\theta_{t-1})\big) mt?=λ?mt?1?+η(?L(θt?1?))

AdamW

  • m t = β 1 m t ? 1 + ( 1 ? β 1 ) ? L ( θ t ? 1 ) m_t=\beta_1m_{t-1}+(1-\beta_1)\nabla L(\theta_{t-1}) mt?=β1?mt?1?+(1?β1?)?L(θt?1?)
  • v t = β 2 v t ? 1 + ( 1 ? β 2 ) ( ? L ( θ t ? 1 ) ) 2 v_t=\beta_2v_{t-1}+(1-\beta_2)\big(\nabla L(\theta_{t-1})\big)^2 vt?=β2?vt?1?+(1?β2?)(?L(θt?1?))2
  • θ t = θ t ? 1 ? η ( 1 v ^ t + ? m ^ t ? r ? θ t ? 1 ) \theta_t=\theta_{t-1}-\eta(\dfrac{1}{\sqrt{\hat v_t}+\epsilon}\hat m_t-r\cdot\theta_{t-1}) θt?=θt?1??η(v^t? ?+?1?m^t??r?θt?1?)

L2 Regularization:
SGDM

  • θ t = θ t ? 1 ? λ m t ? 1 ? η ( ? L ( θ t ? 1 ) + r ? θ t ? 1 ) \theta_t=\theta_{t-1}-\lambda m_{t-1}-\eta\big(\nabla L(\theta_{t-1})+r\cdot\theta_{t-1}\big) θt?=θt?1??λmt?1??η(?L(θt?1?)+r?θt?1?)
  • m t = λ m t ? 1 + η ( ? L ( θ t ? 1 ) + r ? θ t ? 1 ) m_t=\lambda m_{t-1}+\eta \big(\nabla L(\theta_{t-1})+r\cdot\theta_{t-1}\big) mt?=λmt?1?+η(?L(θt?1?)+r?θt?1?)

Adam

  • m t = λ m t ? 1 + η ( ? L ( θ t ? 1 ) + r ? θ t ? 1 ) m_t=\lambda m_{t-1}+\eta\big(\nabla L(\theta_{t-1})+r\cdot\theta_{t-1}\big) mt?=λmt?1?+η(?L(θt?1?)+r?θt?1?)
  • v t = β 2 v t ? 1 + ( 1 ? β 2 ) ( ? L ( θ t ? 1 ) + r ? θ t ? 1 ) 2 v_t=\beta_2v_{t-1}+(1-\beta_2)\big(\nabla L(\theta_{t-1})+r\cdot\theta_{t-1}\big)^2 vt?=β2?vt?1?+(1?β2?)(?L(θt?1?)+r?θt?1?)2

4 Why Deep?

首先比较Deep Model与Shallow Model的表现,比较的前提是神经网络的参数数量一样多,结论是Deep Model的错误率比Shallow Model低一些。

——Modularization

此种结构化的架构支持代码复用,降低主函数的复杂度,Deep Model可以做到Modularization。如区分长头发女生、长头发男生、短头发女生和短头发男生的图像的任务,如果训练四个分类器,则长头发男生的分类器由于训练集小导致它很弱。Modularization训练两个基本分类器——区分男生女生和区分长短,这两个基本分类器都有足够数量的训练数据。DNN中每一个神经元可看作是一个基本分类器,第二层神经元利用第一层神经元的输出,只需少量数据就可完成复杂些的分类器训练。
6
神奇的是,在DNN中Modularization是机器从数据中自动学习的。
Modularization的好处是使模型变得简单,需要的训练数据数量变少,Deep Learning实际是降低对训练数据数量的要求。

回顾前面的文章,Deep的结构免于人们手工Feature Transform,每个隐含层的输出结果逐渐清晰。
7

5 总结与展望

这篇文章想记录一下研究算法的思路,以及深度学习的优势。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-27 11:51:08  更:2021-08-27 11:51:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 22:35:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码