SAC (Soft Actor-Critic): Learning Notes
Introduction
The SAC (Soft Actor-Critic) algorithm has attracted a lot of attention in recent years and is well regarded by many deep reinforcement learning researchers. This article covers the theoretical analysis of SAC and the implementation of its core code.
Unlike many deep reinforcement learning algorithms that simply maximize the cumulative reward, SAC maximizes the entropy-regularized cumulative reward, which encourages the agent to explore more and leads to better training results:
$$\max_{\pi_{\theta}}\;\mathbb{E}\left[\sum_{t}\gamma^{t}\Big(r(S_{t},A_{t})+\alpha\mathcal{H}\big(\pi_{\theta}(\cdot\mid S_{t})\big)\Big)\right]$$

The goal of SAC is to find a stochastic policy
$$\pi^{\ast}=\arg\max_{\pi_{\theta}}\sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\rho_{\pi_{\theta}}}\left[r(\mathbf{s}_{t},\mathbf{a}_{t})+\alpha\mathcal{H}\big(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\big)\right]$$

In general, V and Q are related by
$$\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t})\equiv\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]$$

In SAC, however, we use the soft value function:
$$\begin{aligned}
\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t}) &=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\,\mathcal{H}\big(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\big)\\
&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\,\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[-\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\\
&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})-\alpha\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]
\end{aligned}$$

SAC comes in two versions. The first uses a Q network, a V network, and a policy network, with a fixed entropy-regularization coefficient. The second version removes the V network, uses double Q networks instead, and introduces a way to adjust the entropy coefficient automatically. The first version is introduced below, followed by the second.
SAC (Version 1)
Objective of the V network
$$J_{V}(\psi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[\frac{1}{2}\Big(V_{\psi}(\mathbf{s}_{t})-\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\phi}}\big[Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\log\pi_{\phi}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\big]\Big)^{2}\right]$$
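As a rough PyTorch sketch of how this loss could be computed for a mini-batch (my own addition; `value_net` is a hypothetical state-value network analogous to the `Critic` below but taking only the state, while `q_net` and `policy.evaluate` follow the interfaces of the `Critic` and `Actor` defined later in this article):

```python
import torch
import torch.nn.functional as F

def v_loss(value_net, q_net, policy, state):
    # Actions are sampled from the current policy, not taken from the replay buffer
    action, log_prob = policy.evaluate(state)
    with torch.no_grad():
        # Target: E_a[Q(s, a) - log pi(a|s)], treated as a constant w.r.t. psi
        target_v = q_net(state, action) - log_prob
    # J_V(psi) = 1/2 * (V_psi(s) - target)^2, averaged over the batch
    return 0.5 * F.mse_loss(value_net(state), target_v)
```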
Objective of the Q network
$$J_{Q}(\theta)=\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\mathcal{D}}\left[\frac{1}{2}\Big(Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})\Big)^{2}\right]$$

where the target is $\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})=r(\mathbf{s}_{t},\mathbf{a}_{t})+\gamma\,\mathbb{E}_{\mathbf{s}_{t+1}}\left[V_{\bar{\psi}}(\mathbf{s}_{t+1})\right]$, computed with a target V network $\bar{\psi}$.
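A corresponding sketch of the Q-network loss under the same assumptions (`target_value_net` is a hypothetical target copy of the V network, and `not_done` masks out terminal transitions):

```python
import torch
import torch.nn.functional as F

def q_loss(q_net, target_value_net, state, action, reward, next_state, not_done, gamma=0.99):
    with torch.no_grad():
        # Bellman target built from the *target* V network (version 1 of SAC)
        target_q = reward + not_done * gamma * target_value_net(next_state)
    # J_Q(theta) = 1/2 * (Q_theta(s, a) - target)^2, averaged over the batch
    return 0.5 * F.mse_loss(q_net(state, action), target_q)
```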
Objective of the Policy network
$$J_{\pi}(\phi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[\mathrm{D}_{\mathrm{KL}}\left(\pi_{\phi}(\cdot\mid\mathbf{s}_{t})\,\Big\|\,\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}\right)\right]$$

At first glance this form of the policy objective may be a bit puzzling. In fact,
$\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}$ is the solution of the optimization problem below, where $Z_{\theta}(\mathbf{s}_{t})$ is a normalization term, $Z(\mathbf{s})=\sum_{a}\exp\left(\frac{1}{\alpha}Q(\mathbf{s},a)\right)$:
$$\pi^{\ast}=\arg\max_{\pi_{\theta}}\sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\rho_{\pi_{\theta}}}\left[r(\mathbf{s}_{t},\mathbf{a}_{t})+\alpha\mathcal{H}\big(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\big)\right]$$

If the policy model we use cannot represent this optimal policy π exactly, the next best thing is to minimize the KL divergence between the two.
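Expanding the KL divergence makes the link to the usual SAC actor loss explicit (a short derivation in the same notation as above; $\log Z_{\theta}(\mathbf{s}_{t})$ does not depend on $\phi$ and can be dropped):

$$\mathrm{D}_{\mathrm{KL}}\left(\pi_{\phi}(\cdot\mid\mathbf{s}_{t})\,\Big\|\,\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}\right)=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\phi}}\left[\log\pi_{\phi}(\mathbf{a}_{t}\mid\mathbf{s}_{t})-Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\log Z_{\theta}(\mathbf{s}_{t})$$

so minimizing $J_{\pi}(\phi)$ amounts to minimizing $\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D},\,\mathbf{a}_{t}\sim\pi_{\phi}}\left[\log\pi_{\phi}(\mathbf{a}_{t}\mid\mathbf{s}_{t})-Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})\right]$, which is exactly the actor loss `(log_prob - min_q).mean()` used in the code later in this article.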
SAC (Version 2)
Version 1 of SAC uses three networks. However, since the V network and the Q network are inherently related, the second version of SAC drops the V network and uses double Q networks to mitigate overestimation. It also provides a way to adjust $\alpha$ automatically. In general, the second version of SAC is the recommended one. Version 2 is similar to version 1 in most respects, so this article focuses on the differences.
The regularization coefficient can be tuned automatically by minimizing the loss function below, where $\kappa=-\dim(\mathcal{A})$ is the target entropy:
$$J(\alpha)=\mathbb{E}_{a\sim\pi_{\theta}}\left[-\alpha\log\pi_{\theta}(a\mid s)-\alpha\kappa\right]$$

Readers interested in the detailed derivation can refer to the SAC papers (References 1 and 2).
Re-parameterization
Re-parameterization reduces the variance of the expectation estimate and lets gradients propagate back through the sampling step, so SAC uses this trick. Suppose the policy outputs the action mean $\mu_{\theta}$ and standard deviation $\sigma_{\theta}$; the action is then computed as

$$a_t=\tanh(\mu_{\theta}+\epsilon\cdot\sigma_{\theta}),\qquad\epsilon\sim\mathcal{N}(0,1)$$

The corresponding Python code is:
```python
from torch.distributions import Normal
# mean and std are produced by the policy network's forward pass
normal = Normal(mean, std)
z = normal.rsample()
```
In PyTorch, `Normal` provides both `sample` and `rsample`. `sample` draws directly from the defined distribution, while `rsample` first samples from the standard normal distribution N(0, 1) and then returns mean + std × sample, so gradients can flow back through mean and std. For re-parameterization, use `rsample`. In my own experience, the agent failed to learn a good policy when I used `sample`, but training converged quickly after switching to `rsample`.
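A quick illustrative snippet (not part of the SAC implementation itself) showing the difference: gradients flow back to `mean` and `std` only when `rsample` is used.

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)
dist = Normal(mean, std)

z = dist.rsample()       # re-parameterized: z = mean + std * eps, eps ~ N(0, 1)
z.sum().backward()
print(mean.grad)         # gradients reach mean (and std) through the sample

z2 = dist.sample()       # plain sampling: no gradient path back to mean/std
print(z2.requires_grad)  # False
```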
Code Examples
Policy network
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action=1, init_w=3e-3):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3_mean = nn.Linear(128, action_dim)
        self.log_std_linear = nn.Linear(128, action_dim)
        self.max_action = max_action
        # Small uniform initialization for the output layers
        self.l3_mean.weight.data.uniform_(-init_w, init_w)
        self.l3_mean.bias.data.uniform_(-init_w, init_w)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        mean = self.l3_mean(x)
        log_std = self.log_std_linear(x)
        # Clamp log-std to keep the policy distribution numerically stable
        log_std = torch.clamp(log_std, -20, 2)
        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()                 # re-parameterized sample
        action = torch.tanh(z)               # squash the action into (-1, 1)
        # Change-of-variables correction for the tanh squashing
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob

    def select_action(self, state):
        state = torch.FloatTensor(state).to(device)
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action
```
Note that during training I normalized the action range of the agent in the environment, so max_action = 1.
```python
log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
```

The `log_prob` in this line corresponds to $\log\pi(\mathbf{a}\mid\mathbf{s})$. Its theoretical basis is the following formula from the original paper; `epsilon` is added so that the argument of the second logarithm never becomes vanishingly small.
$$\log\pi(\mathbf{a}\mid\mathbf{s})=\log\mu(\mathbf{u}\mid\mathbf{s})-\sum_{i=1}^{D}\log\left(1-\tanh^{2}(u_{i})\right)$$
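As a quick sanity check of the `Actor` above (the state and action dimensions here are arbitrary examples), the output shapes can be verified like this:

```python
# Hypothetical dimensions, just for a shape check
actor = Actor(state_dim=8, action_dim=2).to(device)
state = torch.randn(32, 8).to(device)   # a batch of 32 states
action, log_prob = actor.evaluate(state)
print(action.shape, log_prob.shape)     # torch.Size([32, 2]) torch.Size([32, 1])
```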
Q network
```python
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, init_w=3e-3):
        super(Critic, self).__init__()
        # Q(s, a): the state and action are concatenated at the input
        self.l1 = nn.Linear(state_dim + action_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 1)
        self.l3.weight.data.uniform_(-init_w, init_w)
        self.l3.bias.data.uniform_(-init_w, init_w)

    def forward(self, x, u):
        x = F.relu(self.l1(torch.cat([x, u], 1)))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x
```
This part is not very different from the Q-network definitions used in other algorithms.
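Version 2 of SAC keeps two independent critics plus a target copy of each. The following is a minimal sketch (my own addition, with assumed hyper-parameters) of how the networks and optimizers referenced by `update()` below could be wired together:

```python
import copy
import torch.optim as optim


class SACAgent:
    """Hypothetical agent skeleton showing how the networks used in update() are set up."""

    def __init__(self, state_dim, action_dim, lr=3e-4):
        self.policy_network = Actor(state_dim, action_dim).to(device)
        self.actor_optimizer = optim.Adam(self.policy_network.parameters(), lr=lr)

        # Two independent critics (double Q) ...
        self.critic_1 = Critic(state_dim, action_dim).to(device)
        self.critic_2 = Critic(state_dim, action_dim).to(device)
        # ... each with a target copy that is soft-updated with coefficient tau
        self.critic_target_1 = copy.deepcopy(self.critic_1)
        self.critic_target_2 = copy.deepcopy(self.critic_2)
        self.critic_optimizer_1 = optim.Adam(self.critic_1.parameters(), lr=lr)
        self.critic_optimizer_2 = optim.Adam(self.critic_2.parameters(), lr=lr)

        self.update_step = 0
```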
Update parameters
```python
def update(self):
    state, action, reward, next_state, done = self.replay_buffer.sample(args.batch_size)
    state = torch.FloatTensor(state).to(device)
    action = torch.FloatTensor(action).to(device)
    reward = torch.FloatTensor(reward).reshape((-1, 1)).to(device)
    next_state = torch.FloatTensor(next_state).to(device)
    # Mask is 0 for terminal transitions so no bootstrapping happens there
    done = torch.FloatTensor(1 - done).reshape((-1, 1)).to(device)

    # --- Q networks: soft Bellman target using the minimum of the two target critics ---
    next_action, next_log_prob = self.policy_network.evaluate(next_state)
    target_Q_1 = self.critic_target_1(next_state, next_action)
    target_Q_2 = self.critic_target_2(next_state, next_action)
    # The entropy coefficient alpha is implicitly fixed to 1 here (its update is not implemented)
    target_Q = torch.min(target_Q_1, target_Q_2) - next_log_prob
    my_target_Q = reward + (done * args.gamma * target_Q)
    current_Q_1 = self.critic_1(state, action)
    current_Q_2 = self.critic_2(state, action)
    critic_loss_1 = F.mse_loss(current_Q_1, my_target_Q.detach())
    critic_loss_2 = F.mse_loss(current_Q_2, my_target_Q.detach())
    critic_loss = critic_loss_1 + critic_loss_2
    self.critic_optimizer_1.zero_grad()
    self.critic_optimizer_2.zero_grad()
    critic_loss.backward()
    self.critic_optimizer_1.step()
    self.critic_optimizer_2.step()

    # --- Policy network: delayed update every second step ---
    if self.update_step % 2 == 0:
        new_action, log_prob = self.policy_network.evaluate(state)
        min_q = torch.min(
            self.critic_1(state, new_action),
            self.critic_2(state, new_action)
        )
        actor_loss = (log_prob - min_q).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

    # --- Soft update of the target critics ---
    for param, target_param in zip(self.critic_1.parameters(), self.critic_target_1.parameters()):
        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
    for param, target_param in zip(self.critic_2.parameters(), self.critic_target_2.parameters()):
        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

    self.update_step += 1
```
The parameter update has three parts: the Q networks, the policy network, and $\alpha$. In the code above I did not implement the third part; readers who want automatic tuning only need to write the update for $\alpha$ following the formula given earlier.
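For reference, a minimal sketch of what that $\alpha$ update could look like (my own addition, not part of the original code; `target_entropy`, `log_alpha`, and `alpha_optimizer` are hypothetical names, and the optimization is done over $\log\alpha$ as is common practice):

```python
import torch
import torch.optim as optim

# Hypothetical setup, e.g. in the agent's __init__:
target_entropy = -float(action_dim)   # kappa = -dim(A)
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optimizer = optim.Adam([log_alpha], lr=3e-4)

# Inside update(), after computing log_prob for actions sampled at the current states:
# J(alpha) = E[-alpha * (log pi(a|s) + kappa)], optimized through log_alpha for stability
alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
# alpha = log_alpha.exp() would then replace the implicit alpha = 1
# in the Q target and in the actor loss above.
```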
References
1:Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor
2:Soft Actor-Critic Algorithms and Applications
3:Deep Reinforcement Learning Fundamentals, Research and Applications
4:From Policy Gradient to Actor-Critic methods Soft Actor Critic, ISIR
5:https://github.com/cyoon1729/Policy-Gradient-Methods