| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> GBDT算法原理讲解以及常用的训练框架汇总:XGBoost LightGBM CatBoost NGBoost -> 正文阅读 |
|
[人工智能]GBDT算法原理讲解以及常用的训练框架汇总:XGBoost LightGBM CatBoost NGBoost |
目录1 基础知识点1.1 Ensemble Learning集成学习 (Ensemble Learning) 是机器学习 (Machine Learning) 里面核心的概念。它的主要思想归纳起来就是:通过训练多个弱学习器来达到一个强学习器的效果,组合后的表现比任何一个弱学习器的效果都好。在机器学习里error可以大概分两类:一类是偏差错误 (bias error):指的是预测值与真实值的差异,另外一类是方差错误 (variance error):指的是预测值作为随机变量的离散程度,而集成学习正好可以降低这些问题。通过组合多个分类器的结果可以降低模型预测的偏差,特别是对于一些不稳定的学习器,所以集成学习出来的学习器有更高的稳定性。而在集成学习中,常见的方法是Baging和Boosting,接下来我们先对这两个方法进行简单描述。 1.2 Bagging and Boosting使用Bagging或者Boosting技术,我们必须选择一个base学习器。例如,我们可以选择分类tree,那么Bagging和Boosting会组合成一系列树学习器变成一个集成的学习器,接下来,Bagging和Boosting怎么训练得到N个学习器呢?
在Boosting算法中,每个分类器的训练数据选择,依赖于上个分类器的预测结果,所以在每个训练步骤中,样本的权重是会重新调整的,其中预测错误的数据会增大权重,在概率上会更有可能被选中进入下个分类器进行训练,重点关注这些hard sample的识别。
1.3 Adaptive BoostingAdaptive Boosting (AdaBoost) 是一种Boosting方法,Boosting核心思想就是从上一个模型错误中进行学习。而AdaBoost学习方法主要是通过对分错的样本加大权重,让下一个模型更关注分错的样本识别效果。训练的基本步骤如下:
1.3 Gradient Boosting梯度提升 (Gradient Boosting) 也是一种Boosting方法,在上面我们提到,boosting模型的核心点是从先前的错误中进行学习,而Gradient Boosting每次迭代通过直接拟合上一步的残差 (目标损失函数对输出值的偏导数) ,使得当前t步的预测结果等于目标损失函数对上一步t-1预测值的负梯度方向,从而可以通过每次迭代 ( f t ( x i ) = f t ? 1 ( x i ) ? ? L ( y i , f t ? 1 ( x i ) ) ? f t ? 1 ( x i ) f_t(x_i) = f_{t-1}(x_i)-\frac{\partial L(y_i,f_{t-1}(x_i))}{\partial f_{t-1}(x_i)} ft?(xi?)=ft?1?(xi?)??ft?1?(xi?)?L(yi?,ft?1?(xi?))?)不断降低目标损失loss,算法流程如下: 1. 初 始 化 : f 0 ( x ) = argmin γ ∑ i = 1 N L ( y i , γ ) 1.初始化: f_0(x) = \text{argmin}_\gamma \sum_{i=1}^N L(y_i, \gamma) 1.初始化:f0?(x)=argminγ?∑i=1N?L(yi?,γ) 2. for? t = 1 to? T : 2. \text{for} \text{ } t=1 \text{} \text{to} \text{ }T: 2.for?t=1to?T: ??????? ( a ) 计 算 负 梯 度 : y ^ i = ? ? L ( y i , f t ? 1 ( x i ) ) ? f t ? 1 ( x i ) , i = 1 , 2 , . . . N \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }(a) 计算负梯度: \hat{y}_i =- \frac{\partial L(y_i,f_{t-1}(x_i))}{\partial f_{t-1}(x_i)}, i=1,2,...N ???????(a)计算负梯度:y^?i?=??ft?1?(xi?)?L(yi?,ft?1?(xi?))?,i=1,2,...N ??????? ( b ) 通 过 最 小 化 平 方 误 差 , 用 基 学 习 器 h t ( x ) 拟 合 y ^ i , \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }(b)通过最小化平方误差,用基学习器h_t(x)拟合\hat{y}_i, ???????(b)通过最小化平方误差,用基学习器ht?(x)拟合y^?i?, ?????????? w t = argmin w ∑ i = 1 N L ( y ^ i ? h t ( x i ; w ) ] 2 \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ } w_t = \text{argmin}_w \sum_{i=1}^N L(\hat{y}_i - h_t(x_i; w)]^2 ??????????wt?=argminw?∑i=1N?L(y^?i??ht?(xi?;w)]2 ??????? ( c ) 使 用 L i n e s e a r c h 确 定 步 长 ρ m , 以 使 L 最 小 \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }(c) 使用Linesearch确定步长\rho_m,以使L最小 ???????(c)使用Linesearch确定步长ρm?,以使L最小, ?????????? ρ t = argmin ρ ∑ i = 1 N L ( y i , f t ? 1 ( x i ) + ρ h t ( x i ; w t ) ) \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ } \rho_t = \text{argmin}_{\rho} \sum_{i=1}^N L(y_i, f_{t-1}(x_i) + \rho h_t(x_i;w_t)) ??????????ρt?=argminρ?∑i=1N?L(yi?,ft?1?(xi?)+ρht?(xi?;wt?)) ??????? ( d ) f t ( x ) = f t ? 1 ( x ) + ρ t h t ( x ; w t ) \text{ }\text{ }\text{ }\text{ }\text{ }\text{ }\text{ }(d) f_t(x) = f_{t-1}(x)+\rho_th_t(x;w_t) ???????(d)ft?(x)=ft?1?(x)+ρt?ht?(x;wt?)
3.
输
出
f
M
(
x
)
3.输出 f_M(x)
3.输出fM?(x) 2 GBDT算法2.1 原理GBDT (Gradient Boosting Decision Tree) 是梯度提升树,接下来我们详细推导算法的细节过程。 1)GBDT预测结果值表达 2)定义目标损失函数
O
b
j
=
∑
i
=
1
n
l
(
y
i
,
y
i
^
)
+
∑
k
=
1
K
Ω
(
f
k
)
Obj = \sum_{i=1}^n l(y_i, \hat{y_i}) + \sum_{k=1}^K\Omega(f_k)
Obj=i=1∑n?l(yi?,yi?^?)+k=1∑K?Ω(fk?) 3)目标损失函数变形
y
i
^
(
0
)
=
0
\hat{y_i}^{(0)} = 0
yi?^?(0)=0 从上述公式可以看出,在t步,最终的预测结果是累积前面t-1步的所有结果与当前的树的结果之和(这里暂时先不考虑学习率因子控制每棵树的权值分值的缩放)。我们对目标函数进行展开如下:
O
b
j
(
t
)
=
∑
i
=
1
n
(
y
i
?
(
y
i
^
(
t
?
1
)
+
f
t
(
x
i
)
)
)
2
+
Ω
(
f
t
)
+
c
o
n
s
t
Obj^{(t)} = \sum_{i=1}^n (y_i - (\hat{y_i}^{(t-1)} + f_t(x_i)))^2 + \Omega(f_t) + const
Obj(t)=∑i=1n?(yi??(yi?^?(t?1)+ft?(xi?)))2+Ω(ft?)+const 4)泰勒公式 5)目标函数用泰勒展开式表示 6)正则化项
Ω
(
f
t
)
\Omega(f_t)
Ω(ft?) 则
Ω
\Omega
Ω计算结果为:
γ
3
+
1
2
λ
(
4
+
0.01
+
1
)
\gamma^3 + \frac{1}{2}\lambda(4+0.01+1)
γ3+21?λ(4+0.01+1),在前面的目标损失函数中,我们知道
f
t
(
x
)
f_t(x)
ft?(x)是模型结果的预测分值,那么对于树模型结构我们令: 我们定义: 得到了最终的目标损失函数表达,我们要求解的是
g
i
g_i
gi?和
h
i
h_i
hi?,其中
G
j
G_j
Gj?和
H
j
H_j
Hj?分别表示的是在叶子节点
j
j
j的所有样本的一阶导数
g
g
g分数之和以及二阶导数
h
h
h的分数之和。所以对于任何一种结构树,样本所到达的叶子节点,我们都可以计算出目标损失函数值,如下图所示: 7)如何求解一颗树的最优分割结构 2.2 训练熟悉GBDT的原理后,接下来我们来看下GBDT模型是如何训练的,以及训练完后,我们得到的到底是什么样的模型结构 训练步骤
模型结构 2.3 预测当训练好了模型以后,预测的过程就简单了,假设有T棵树,则最终的模型预测结果为这个样本落入到每棵树的叶子节点分值之和,用公式表达如下: 3 训练框架接下来介绍优化Gradient Boosting算法的几种分布式训练框架,这些框架支持分布式训练,树的调优,缺失值处理,正则化等避免过拟合问题。 3.1 XGBoostXGBoost: A Scalable Tree Boosting System 是由2014年5月,由DMLC开发出来的,目前是比较受欢迎,高效分布式训练Gradient Boosted Trees算法框架,包含的详细资料可以参考官方文档: 官网文档。 3.2 LightGBMLightGBM: A Highly Efficient Gradient Boosting Decision Tree 是由微软团队于2017年1月针对XGBoost框架存在的一些问题,设计出的更加高效的学习框架,主要基于梯度单边采样GOSS((Gradient Based One Side Sampling)以及互斥特征合并EFB(Exclusive Feature Bundling)来加快模型的学习效率,详细资料可以参考官方文档 官网文档。下面提供的是一个简单的在基于lightgbm做rank排序代码:
3.3 CatBoostCatBoost: unbiased boosting with categorical features 是由2017年4月,俄罗斯搜索巨头Yandex开发出的优化xgboost的一个框架,该框架最大的优势就是能够处理categorical features,相比LightGBM,不需要对分类特征进行label encoding,方便用户的快速操作。详细的官方网址:官网网址,下面是三者的详细对比图:
3.4 NGBoostNGBoost: Natural Gradient Boosting for Probabilistic Prediction 是一个比较新的训练Gradient Boosting算法框架。于2019年10月由斯坦福吴恩达团队公开发表。github代码记录在:NGBoost Github,核心点在于它使用自然梯度提升,一种用于概率预测的模块化提升算法。该算法由?基学习?器、?参数概率分布?和?评分规则?组成。 4 树模型与深度模型结合树模型有较强的解释性和稳定性,但是缺陷就是无语义特征,泛化能力不够强,所以在实际场景中,可以将树模型与深度学习模型结合,比如简单的求个平均可能都有一定提升,这是一个简单试验对比效果,bert比gbdt模型在情感分类任务中有一定提升,但将bert和gbdt结果融合到一起简单求一个平均值,效果最佳,具体可以参考有人做的一个简单对比试验:bert vs catboost 6 参考资料 |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/8 4:45:13- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |