Pytorch 数值稳定性,模型初始化
这一节内容有很多数学推导,大家可以多看看李沐老师的视频和教材理解理解。我摊牌了,这一章我没完全听懂。就大概记了下能大概听懂的内容,准备以后的学习中进一步加深对本节课的理解。
1. 数值稳定性
1.1 神经网络的梯度
考虑如下有
d
d
d 层神经网络:
h
t
=
f
t
(
h
t
?
1
)
?and?
y
=
?
°
f
d
°
…
°
f
1
(
x
)
\mathbf{h}^{t}=f_{t}\left(\mathbf{h}^{t-1}\right) \quad \text { and } \quad y=\ell \circ f_{d} \circ \ldots \circ f_{1}(\mathbf{x})
ht=ft?(ht?1)?and?y=?°fd?°…°f1?(x) 计算损失
?
\ell
? 关于
W
t
\mathbf{W}^{t}
Wt 的 的梯度:
?
l
?
W
t
=
?
l
?
h
d
?
h
d
?
h
d
?
1
…
?
h
t
+
1
?
h
t
?
h
t
?
W
t
\frac{\partial l}{\partial \mathbf{W}^{t}}=\frac{\partial l}{\partial \mathbf{h}^{d}} \frac{\partial \mathbf{h}^{d}}{\partial \mathbf{h}^{d-1}} \ldots \frac{\partial \mathbf{h}^{t+1}}{\partial \mathbf{h}^{t}} \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{W}^{t}}
?Wt?l?=?hd?l??hd?1?hd?…?ht?ht+1??Wt?ht?
1.2 梯度爆炸和梯度消失
当每层梯度都是大于 1 的情况下,层数变多,最后得到的数值会越来越大。
当每层梯度都是小于 1 的情况下,层数变多,最后得到的数值会越来越小。
1.3 例子:MLP
加入如下 MLP (为了简便先不考虑偏置
b
b
b):
f
t
(
h
t
?
1
)
=
σ
(
W
t
h
t
?
1
)
?
h
t
?
h
t
?
1
=
diag
?
(
σ
′
(
W
t
h
t
?
1
)
)
(
W
t
)
T
∏
i
=
t
d
?
1
?
h
i
+
1
?
h
i
=
∏
i
=
t
d
?
1
diag
?
(
σ
′
(
W
i
h
i
?
1
)
)
(
W
i
)
T
f_{t}\left(\mathbf{h}^{t-1}\right)=\sigma\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right) \\ \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{h}^{t-1}}=\operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right)\right)\left(W^{t}\right)^{T} \\ \prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
ft?(ht?1)=σ(Wtht?1)?ht?1?ht?=diag(σ′(Wtht?1))(Wt)Ti=t∏d?1??hi?hi+1?=i=t∏d?1?diag(σ′(Wihi?1))(Wi)T 其中
σ
\sigma
σ 是激活函数,
σ
′
\sigma^{\prime}
σ′ 是
σ
\sigma
σ 的导函数。
1.3.1 梯度爆炸
使用
R
e
L
U
ReLU
ReLU 作为激活函数:
σ
(
x
)
=
max
?
(
0
,
x
)
?and?
σ
′
(
x
)
=
{
1
?if?
x
>
0
0
?otherwise?
\sigma(x)=\max (0, x) \quad \text { and } \quad \sigma^{\prime}(x)= \begin{cases}1 & \text { if } x>0 \\ 0 & \text { otherwise }\end{cases}
σ(x)=max(0,x)?and?σ′(x)={10??if?x>0?otherwise??
∏
i
=
t
d
?
1
?
h
i
+
1
?
h
i
=
∏
i
=
t
d
?
1
diag
?
(
σ
′
(
W
i
h
i
?
1
)
)
(
W
i
)
T
\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
∏i=td?1??hi?hi+1?=∏i=td?1?diag(σ′(Wihi?1))(Wi)T 的一些元素会来自于
∏
i
=
t
d
?
1
(
W
i
)
T
\prod_{i=t}^{d-1}\left(W^{i}\right)^{T}
∏i=td?1?(Wi)T。如果
d
?
t
d-t
d?t 很大,得到的数值将会很大。
梯度爆炸的问题:
- 值超出值域
- 对学习率敏感
- 若学习率太大->大参数值->更大的梯度
- 若学习率太小->训练无进展
- 我们可能需要在驯良过程中不断调整学习率
1.3.2 梯度消失
使用
s
i
g
m
o
i
d
sigmoid
sigmoid 作为激活函数:
σ
(
x
)
=
1
1
+
e
?
x
σ
′
(
x
)
=
σ
(
x
)
(
1
?
σ
(
x
)
)
\sigma(x)=\frac{1}{1+e^{-x}} \quad \sigma^{\prime}(x)=\sigma(x)(1-\sigma(x))
σ(x)=1+e?x1?σ′(x)=σ(x)(1?σ(x))
∏
i
=
t
d
?
1
?
h
i
+
1
?
h
i
=
∏
i
=
t
d
?
1
diag
?
(
σ
′
(
W
i
h
i
?
1
)
)
(
W
i
)
T
\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
∏i=td?1??hi?hi+1?=∏i=td?1?diag(σ′(Wihi?1))(Wi)T 的元素值是
d
?
t
d-t
d?t 个小数值的乘积。
梯度消失的问题:
- 梯度值变成
0
0
0
- 训练没有进展
- 对于底部层尤为严重
2. 模型初始化
2.1 让训练更加稳定
- 目标:让梯度值在合理的范围内
- 例如
[
1
e
?
6
,
1
e
3
]
[1e-6, 1e3]
[1e?6,1e3]
- 将乘法变加法
- 归一化
- 合理的权重初始和激活函数
2.2 让每层的方差是一个常数
- 将每层的输出和梯度都看作随机变量
- 让它们的均值和方差都保持一致
2.3 权重初始化
- 在合理值区间里随机初始参数
- 训练开始的时候更容易有数值不稳定
- 远离最优解的地方损失函数表面可能很复杂
- 最优解附近表面会比较平
- 使用
N
(
0
,
0.01
)
N(0, 0.01)
N(0,0.01) 来初始可能对小网络没问题,但不能保证深度神经网络。
2.4 Xavier 初始化
Xavier 初始化从均值为零,方差
σ
2
=
2
n
i
n
+
n
o
u
t
\sigma^2 = \frac{2}{n_\mathrm{in} + n_\mathrm{out}}
σ2=nin?+nout?2? 的高斯分布中采样权重。 我们也可以利用 Xavier 的直觉来选择从均匀分布中抽取权重时的方差。 注意均匀分布
U
(
?
a
,
a
)
U(-a, a)
U(?a,a) 的方差为
a
2
3
\frac{a^2}{3}
3a2?。 将
a
2
3
\frac{a^2}{3}
3a2? 代入到
σ
2
\sigma^2
σ2 的条件中,将得到初始化值域:
U
(
?
6
n
i
n
+
n
o
u
t
,
6
n
i
n
+
n
o
u
t
)
.
U\left(-\sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}, \sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}\right).
U(?nin?+nout?6?
?,nin?+nout?6?
?). 尽管在上述数学推理中,“不存在非线性”的假设在神经网络中很容易被违反, 但 Xavier 初始化方法在实践中被证明是有效的。
Xavier 初始化表明,对于每一层,输出的方差不受输入数量的影响,任何梯度的方差不受输出数量的影响。
|