在训练模型的过程中,有时候我们需要设置两个模型,一个是随着训练数据的加入进行参数更新的model,另一个模型是作为model的平均模型model_ema,对model_ema进行更新的时候,采用的方式为:
θ
t
m
o
d
e
l
_
e
m
a
=
β
θ
t
?
1
m
o
d
e
l
_
e
m
a
+
(
1
?
β
)
θ
t
m
o
d
e
l
\theta_t^{model\_ema} =\beta\theta_{t-1}^{model\_ema}+(1-\beta)\theta_t^{model}
θtmodel_ema?=βθt?1model_ema?+(1?β)θtmodel? pytorch 的实现如下所示:
def update_model_ema(model, ema_model, alpha):
model_state = model.state_dict()
model_ema_state = ema_model.state_dict()
new_dict = {}
for key in model_state:
new_dict[key] = alpha * model_ema_state[key] + (1 - alpha) * model_state[key]
ema_model.load_state_dict(new_dict)
|