<机器学习中的梯度下降>-随机梯度下降(SGD)以及mini-batch \batch Gradient Descent
梯度下降是机器学习中的基石,我们可以利用梯度下降算法,对损失函数(loss function)求最值,来得到一个对于系统模型最好的参数或者说权重(weight)。
我们定义一个线性回归模型:
f
(
x
)
=
w
T
x
+
b
f(x) = w^Tx+b
f(x)=wTx+b
其中
x
=
(
x
1
;
x
2
;
.
.
.
;
x
d
)
T
x = (x1; x2 ;...;xd)^T
x=(x1;x2;...;xd)T , 是一个数据集中的样本,其中有
d
d
d 个特征输入。
w
w
w同理,
w
=
(
w
1
;
w
2
;
.
.
.
;
w
d
)
T
w = (w1;w2;...;wd)^T
w=(w1;w2;...;wd)T。
对于单个样本而言,在一般做法中,为方便叙述,我们将参数
w
w
w 与
b
b
b 进行合并,形成一个新的向量。
w
′
=
{
w
;
b
}
=
(
w
1
,
w
2
,
w
3
,
.
.
.
,
w
d
,
b
)
T
w' = \{w;b\} = (w_1,w_2,w_3,...,w_d,b)^T
w′={w;b}=(w1?,w2?,w3?,...,wd?,b)T
再将
x
x
x 加一维 1:
x
′
=
[
x
1
T
x
2
T
x
3
T
.
.
.
x
d
T
1
]
x' = \begin{bmatrix} x^T_1 \\ x^T_2 \\ x^T_3 \\ ... & \\ x^T_d \\ 1 \end{bmatrix}
x′=?????????x1T?x2T?x3T?...xdT?1???????????
所以原模型可以简化为:
f
(
X
)
=
w
′
T
x
f(X) = w'^Tx
f(X)=w′Tx
故可定义损失函数(loss function):
L
(
w
)
=
(
y
?
w
′
T
x
)
2
L(w) = (y - w'^Tx)^2
L(w)=(y?w′Tx)2
y
y
y 为训练集中样本的真实输出,即标签。
那么对于整个训练集的成本函数(cost function):
C
(
w
)
=
1
n
∑
i
=
0
n
L
(
w
)
=
1
n
∑
i
=
0
n
(
y
?
w
′
T
x
)
2
C(w) = \frac{1}{n}\displaystyle\sum_{i = 0}^nL(w) = \frac{1}{n}\displaystyle\sum_{i = 0}^n(y - w'^Tx)^2
C(w)=n1?i=0∑n?L(w)=n1?i=0∑n?(y?w′Tx)2
上式中:
-
n
=
1
n = 1
n=1 时, 即 stochastic gradient descent 随机梯度下降**(SGD)**
-
1
<
n
<
m
1 < n < m
1<n<m(
m
m
m 是整个训练集的大小,样本数量),则是mini-batch Gradient Descent 为小批量梯度下降(MBGD)
-
n
=
m
n = m
n=m 时即batch Gradient Descent 批量梯度下降(BGD)
优缺点
BGD
在梯度下降的每一次迭代中,都用到了所有的训练样本。
优点:
- 由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。当目标函数为凸函数时,一定能收敛到全局最小值,如果目标函数非凸则收敛到局部最小值。
- 它对梯度的估计是无偏的。样例越多,标准差越低。
- 一次迭代是对所有样本进行计算,此时利用向量化进行操作,实现了并行。
缺点
- 训练成本太高,若 数据集规模很大,那么对于每一次迭代,机器的计算成本都会很高
SGB
与批量梯度下降相比,在每一次迭代中,值采用一个样本来进行梯度下降。
优点:
- 由于不是在全部训练数据上的损失函数,而是在每轮迭代中,随机优化某一条训练数据上的损失函数,这样每一轮参数的更新速度大大加快。
缺点:
- 每次迭代并不一定都是模型整体最优化的方向。若样本噪声较多,很容易陷入局部最优解而收敛到不理想的状态。
MBGD
大多数情况会使用这种梯度下降方式,而其实对于大多数机器学习框架,采用的SGD 其实也是基于 MBGD 来做的。
综合了 极端SGD 和 BGD的优缺点。
除了梯度方向
在机器学习梯度下降中,除了在利用样本来计算梯度得到下降方向外,更多还会采用学习率衰减的方式。
在梯度下降初期,能接受较大的步长(学习率),以较快的速度进行梯度下降。当收敛时,我们希望步长小一点,并且在最小值附近小幅摆动。
若一开始令学习率很高,虽然模型会收敛的很快,但很容易会在收敛点附件摆动,无法得到一个有效值,若一开始令学习率很低,那么模型的收敛速度会大幅度降低,因此我们更希望得到一个变化的学习率。
目前最常用的策略是:Adam ,为RMSProp + Momentum。
看过李宏毅老师的课,对其有大概的认知,但详细的Adam ,需要之后再深入了解。
|