原理
归一化公式:
y
=
x
?
E
[
x
]
V
a
r
[
x
]
+
?
?
γ
+
β
y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}}*\gamma + \beta
y=Var[x]+?
?x?E[x]??γ+β
其中:
-
E
[
x
]
E[x]
E[x] 是向量
x
x
x 的均值
-
V
a
r
[
x
]
Var[x]
Var[x] 是向量
x
x
x 的方差
-
?
\epsilon
? 常数,通常等于
0.00001
0.00001
0.00001,防止分母为 0
-
γ
\gamma
γ 用于仿射变换
-
β
\beta
β 用于仿射变换
本文介绍的 4 种归一化主要是针对的维度不同,例如 BatchNorm 是对所有 banch 的单个通道归一化,每个通道的归一化独立,而 GroupNorm 是一个 banch 下的通道分组归一化,不受 banch size 的影响,如下图:
1 BatchNorm
BN 是对所有 banch 的单个通道做归一化,每个通道都分别做一次。
torch.nn.BatchNorm2d(num_features, eps=1e-5, momentum=0.1,\
affine=True, track_running_stats=True)
成员变量:
- num_features:通道数。
- eps:常数
?
\epsilon
?。
- momentum:动量参数,用来控制 running_mean 和 running_var 的更新,更新方法:
M
n
e
w
=
(
1
?
m
)
?
M
o
l
d
+
m
?
m
e
a
n
M_{new}=(1-m)*M_{old}+m*mean
Mnew?=(1?m)?Mold?+m?mean,其中,
M
n
e
w
M_{new}
Mnew? 是最新的 running_mean,
M
o
l
d
M_{old}
Mold? 是上一次的 running_mean,
m
e
a
n
mean
mean 是当前批数据的均值。
- affine:仿射变换的开关
- 如果 affine=False,则
γ
=
1
\gamma=1
γ=1、
β
=
0
\beta=0
β=0,且不能学习;(对应weight、bias变量)
- 如果 affine=True,则
γ
\gamma
γ、
β
\beta
β 可以学习;
- training:训练状态或测试状态,两种状态下运行逻辑不通。
- track_running_stats:如果为 True,则统计跟踪 batch 的个数,记录在 num_batches_tracked 中。
- num_btaches_tracked:跟踪 batch 的个数。
trainning 和 tracking_running_stats 有 4 种组合:
trainning | tracking_running_stats | 说明 |
---|
True | True | 正常的训练过程,跟踪整个训练过程的 banch 特性 | True | False | 不跟踪训练过程的 banch 特性,只计算当前的 banch 统计特性 | False | True | 使用之前训练好的 running_mean、running_var,且不会更新 | False | False | (一般不采用)只计算当前特征 |
更新过程:
- running_mean、running_var 是在 forward 过程中更新的,记录在 buffer 中。(反向传播部影响)
-
γ
\gamma
γ、
β
\beta
β 是在反向传播中学习得到的。
- model.eval() 可以固定住 running_mean、running_var。
2 GroupNorm
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-5, affine=True)
3 InstanceNorm
torch.nn.InstanceNorm2d(num_features, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False)
4 LayerNorm
torch.nn.LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True)
参考上面的原理图,LayerNorm 是对一个 banch 的所有通道做归一化,如果输入的 tensor 维度为
[
4
,
6
,
3
,
3
]
[4,6,3,3]
[4,6,3,3],那么函数的传参 normalized_shape 就是
[
6
,
3
,
3
]
[6,3,3]
[6,3,3]。
|