IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> SAC(Soft Actor Critic)学习记录 -> 正文阅读

[人工智能]SAC(Soft Actor Critic)学习记录

SAC(Soft Actor Critic)学习记录

基本介绍

SAC(Soft Actor Critic)算法在近年来受到了许多的关注,得到了不少深度强化学习研究者的好评。这篇文章主要包含的内容有SAC算法的理论分析和核心代码实现。

与许多目的是最大化累计奖励的深度强化学习算法不同,SAC算法的目的是最大化最大化熵正则化的累积奖励,这样能够鼓励智能体有更多的探索,从而达到更好的训练效果。
m a x π θ [ ∑ t γ t ( r ( S t , A t ) + α H ( π θ ( ? ∣ S t ) ) ) ] {max}_{\pi_{\theta}}\left[\sum_{t}\gamma^{t}\left(r(S_{t},A_{t})+\alpha\mathcal{H}(\pi_{\theta}(\cdot|S_{t}))\right)\right] maxπθ??[t?γt(r(St?,At?)+αH(πθ?(?St?)))]
SAC算法的目的是寻找到一个随机策略
π ? ? = arg ? m a x π θ ∑ t ∣ E ( s t , α t ) ~ ρ π θ ? [ r ( s t , α t ) + ? α H ( π θ ( ? ∣ s t ) ) ] \pi^{\ast}\,=\arg{max}_{\pi_{\theta}}\sum_{t}\vert\mathrm{E}_{(\mathrm{s}_{t},\alpha_{t})\sim\rho\pi_{\theta}}\,\left[r(\mathrm{s}_{t},\alpha_{t})+\,\alpha\mathcal{H}(\pi_\theta(\cdot\vert\mathrm{s}_{t}))\right] π?=argmaxπθ??t?E(st?,αt?)ρπθ??[r(st?,αt?)+αH(πθ?(?st?))]
一般而言我们定义V和Q的关系为
V ^ ? π θ ( s t ) ≡ ? ∣ E a t ~ π θ ( . ∣ s t ) ? [ Q ^ ? π θ ( s t , a t ) ] \begin{array}{l}{{\hat{V}_{\phi}^{\pi}\theta\left(\mathbf{s}_{t}\right)\equiv\ |\mathbf{E}_{\mathbf{a}_{t}}\sim\pi_{\theta}(.|\mathbf{s}_{t})\ \left[\hat{Q}_{\phi}^{\pi}\theta\left(\mathbf{s}_{t},\mathbf{a}_{t}\right)\right]}}\end{array} V^?π?θ(st?)?Eat??πθ?(.st?)?[Q^??π?θ(st?,at?)]?
在SAC中我们使用soft update
V ^ ? π θ ( s t ) = E a t ~ π θ ( . ∣ s t ) [ Q ^ ? π θ ( s t , a t ) ] + α H ( π θ ( . ∣ s t ) ) = E a t ~ π θ ( . ∣ s t ) [ Q ^ ? π θ ( s t , a t ) ] + α E a t ~ π θ ( . ∣ s t ) [ ? log ? π θ ( a t ∣ s t ) ] = E a t ~ π θ ( . ∣ s t ) [ Q ^ ? π θ ( s t , a t ) ? α log ? π θ ( a t ∣ s t ) ] \begin{aligned} \hat{V}_{\phi}^{\pi_{\boldsymbol{\theta}}}\left(\mathbf{s}_{t}\right) &=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{\boldsymbol{\theta}}\left(. \mid \mathbf{s}_{t}\right)}\left[\hat{Q}_{\phi}^{\pi_{\theta}}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right]+\alpha \mathcal{H}\left(\pi_{\boldsymbol{\theta}}\left(. \mid \mathbf{s}_{t}\right)\right) \\ &=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{\boldsymbol{\theta}}\left(. \mid \mathbf{s}_{t}\right)}\left[\hat{Q}_{\phi}^{\pi_{\theta}}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right]+\alpha \mathbb{E}_{\mathbf{a}_{t} \sim \pi_{\boldsymbol{\theta}}\left(. \mid \mathbf{s}_{t}\right)}\left[-\log \pi_{\boldsymbol{\theta}}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right] \\ &=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{\boldsymbol{\theta}}\left(. \mid \mathbf{s}_{t}\right)}\left[\hat{Q}_{\phi}^{\pi_{\theta}}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\alpha \log \pi_{\boldsymbol{\theta}}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right] \end{aligned} V^?πθ??(st?)?=Eat?πθ?(.st?)?[Q^??πθ??(st?,at?)]+αH(πθ?(.st?))=Eat?πθ?(.st?)?[Q^??πθ??(st?,at?)]+αEat?πθ?(.st?)?[?logπθ?(at?st?)]=Eat?πθ?(.st?)?[Q^??πθ??(st?,at?)?αlogπθ?(at?st?)]?
SAC有两个版本,第一版使用了Q network, V network,Policy network,熵正则化的系数为定值。第二版的SAC中将V network取消,使用了Double Q network,并且提出了能够动态调节熵正则化系数的方法。这里将先介绍第一种SAC算法,再介绍第二种SAC算法。

SAC(版本一)

V network的目标函数
J V ( ψ ) = E s t ?? ~ ?? D ?? [ 1 2 ( V ψ ( s t ) ? E a t ~ π ? [ Q θ ( s t , a t ) ? log ? π ? ( a t ∣ s t ) ] ) 2 ] J_{V}(\psi)=\mathbb{E}_{\mathbf{s}_{t}}\!\sim\!D\;\left[{\frac{1}{2}}\left(V_{\psi}(\mathbf{s}_{t})-\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\phi}}\left[Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\log\pi_{\phi}(\mathbf{a}_{t}|\mathbf{s}_{t})\right]\right)^{2}\right] JV?(ψ)=Est??D[21?(Vψ?(st?)?Eat?π???[Qθ?(st?,at?)?logπ??(at?st?)])2]
Q network的目标函数
J Q ( θ ) = E ( s t , a t ) ~ D [ 1 2 ( Q θ ( s t , a t ) ? Q ^ ( s t , a t ) ) 2 ] J_{Q}(\theta)=\mathbb{E}_{\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right) \sim \mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\hat{Q}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right)^{2}\right] JQ?(θ)=E(st?,at?)D?[21?(Qθ?(st?,at?)?Q^?(st?,at?))2]
Policy network的目标函数
J π ( ? ) = E S t ~ D [ D K L ( π ? ( ? ∣ s t ) ∣ ∣ exp ? ( Q θ ( s t , ? ) ) Z θ ( s t ) ) ] J_{\pi}(\phi)=\mathbb{E}_{\mathbb{S}_{t}\sim D}\left[\mathrm{D}_{\mathrm{KL}}\left(\pi_{\phi}(\cdot|\mathbf{s}_{t})\left|\right|{\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}}\right)\right] Jπ?(?)=ESt?D?[DKL?(π??(?st?)Zθ?(st?)exp(Qθ?(st?,?))?)]
初看Policy network的目标函数的表示可能会有些不太理解,其实 exp ? ( Q θ ( s t , ? ) ) Z θ ( s t ) \frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})} Zθ?(st?)exp(Qθ?(st?,?))?是下面的式子的解(其中 Z θ ( s t ) Z_{\theta}(\mathbf{s}_{t}) Zθ?(st?)用于归一化, Z ( s ) = ∑ a exp ? ( 1 α Q ( s , a ) ) Z(s)=\sum_{a}\exp\left({\textstyle{\frac{1}{\alpha}}}Q(s,a)\right) Z(s)=a?exp(α1?Q(s,a)))
π ? ? = arg ? m a x π θ ∑ t ∣ E ( s t , α t ) ~ ρ π θ ? [ r ( s t , α t ) + ? α H ( π θ ( ? ∣ s t ) ) ] \pi^{\ast}\,=\arg{max}_{\pi_{\theta}}\sum_{t}\vert\mathrm{E}_{(\mathrm{s}_{t},\alpha_{t})\sim\rho\pi_{\theta}}\,\left[r(\mathrm{s}_{t},\alpha_{t})+\,\alpha\mathcal{H}(\pi_\theta(\cdot\vert\mathrm{s}_{t}))\right] π?=argmaxπθ??t?E(st?,αt?)ρπθ??[r(st?,αt?)+αH(πθ?(?st?))]
如果采用的策略模型无法表达最优的策略π,我们可以让它们的KL散度最小。

SAC(版本二)

在SAC版本一中,使用了三个网络。但是其实V network和Q network本身是有联系的,所以后面在SAC第二个版本的提出中去掉了V network,使用了Double Q network来解决高估问题。并且提供了动态调节 α \alpha α的方法。一般来说,推荐使用第二个版本的SAC算法。版本二的SAC在很多方面都和SAC相似,本文重点介绍不同的方面。

自动化调节正则化参数的方法可以通过最下化下面的损失函数来实现其中 k = ? d i m ( A ) k=-dim(A) k=?dim(A)
J ( α ) = E a ~ π θ [ ? α log ? π θ ( a ∣ s ) ? α κ ] J(\alpha)=\mathbb{E}_{a\sim\pi_{\theta}}\left[-\alpha\log\pi_{\theta}(a|s)-\alpha\kappa\right] J(α)=Eaπθ??[?αlogπθ?(as)?ακ]
具体的证明有兴趣的读者可以参考SAC的论文

重参数化(Re-parameterization)

重参数化能够降低期望估计的方差并且有利于梯度的反向传播,在SAC中使用了重参数化的技巧。假设我们已经知道了动作的均值和标准差 μ θ \mu_{\theta} μθ? σ θ \sigma_{\theta} σθ?,我们需要令
a t = t a n h ( μ θ + ? ? σ θ ) , ? ~ N ( 0 , 1 ) a_t = tanh(\mu_{\theta}+\epsilon\cdot\sigma_{\theta}),\epsilon\sim\mathcal{N}(0,1)\qquad at?=tanh(μθ?+??σθ?),?N(0,1)
对应的Python代码为

from torch.distributions import Normal
normal = Normal(mean, std)
z = normal.rsample()

在Pytorch中Normal有samplersample,sample是直接在定义的分布上采样,rsample是先对标准正太分布N(0,1)进行采样,然后输出:mean+std×采样值,要做重参数化推荐使用rsample。根据我个人的经验,一开始我使用的是sample但是智能体并没有很好的学习到策略,换成了rsample之后很快就完成了训练。

代码示例

Policy network

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action=1, init_w=3e-3):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3_mean = nn.Linear(128, action_dim)
        self.log_std_linear = nn.Linear(128, action_dim)
        self.max_action = max_action

        self.l3_mean.weight.data.uniform_(-init_w, init_w)
        self.l3_mean.bias.data.uniform_(-init_w, init_w)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        mean = self.l3_mean(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, -20, 2)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob

    def select_action(self, state):
        state = torch.FloatTensor(state).to(device)
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action

注意,在训练中,我将环境内智能体的action范围进行了normalization,所以max_action=1。

 log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)

代码中的log_prob对应的是 log ? π ( a ∣ s ) \log\pi(\mathbf{a}|\mathbf{s}) logπ(as),这行代码的理论依据为论文原文的这个公式,epsilon的添加是为了避免第二项出现无穷小。

log ? π ( a ∣ s ) = log ? μ ( u ∣ s ) ? ∑ i = 1 D l o g ( 1 ? tanh ? 2 ( u i ) ) \log\pi(\mathbf{a}|\mathbf{s})=\log\mu(\mathbf{u}|\mathbf{s})-\sum_{i=1}^{D}\mathbf{log}\left(1-\operatorname{tanh}^{2}(\mathbf{u}_{i})\right) logπ(as)=logμ(us)?i=1D?log(1?tanh2(ui?))

Q network

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, init_w=3e-3):
        super(Critic, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 1)

        self.l3.weight.data.uniform_(-init_w, init_w)
        self.l3.bias.data.uniform_(-init_w, init_w)

    def forward(self, x, u):
        x = F.relu(self.l1(torch.cat([x, u], 1)))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x

这部分和以前接触的Q network的定义并没有太多的不同

Update parameters

def update(self):

        # Sample replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(args.batch_size)
        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(1 - done).to(device)

        next_action, next_log_prob = self.policy_network.evaluate(next_state)
        # Compute the target Q value
        target_Q_1 = self.critic_target_1(next_state, next_action)
        target_Q_2 = self.critic_target_2(next_state, next_action)
        target_Q = torch.min(target_Q_1, target_Q_2) - next_log_prob
        my_target_Q = reward.reshape((100, 1)) + (done * args.gamma * target_Q)

        # Get current Q estimate
        current_Q_1 = self.critic_1(state, action)
        current_Q_2 = self.critic_2(state, action)

        # Compute critic loss
        critic_loss_1 = F.mse_loss(current_Q_1, my_target_Q.detach())
        critic_loss_2 = F.mse_loss(current_Q_2, my_target_Q.detach())
        critic_loss = critic_loss_1 + critic_loss_2
        
        # Optimize the critic
        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        if self.update_step % 2 == 0:
            new_action, log_prob = self.policy_network.evaluate(state)
            # Compute actor loss
            min_q = torch.min(
                self.critic_1(state, new_action),
                self.critic_2(state, new_action)
            )
            actor_loss = (log_prob - min_q).mean()
            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic_1.parameters(), self.critic_target_1.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
            for param, target_param in zip(self.critic_2.parameters(), self.critic_target_2.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
        self.update_step += 1

参数更新主要分为三个部分,第一个部分为Q network,第二部分为 Policy network, 第三部分为 α \alpha α。在上述代码中我没有实现第三部分的更新,读者如果想实现自动调节只需根据公式完成代码的编写即可。

Reference

1:Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor

2:Soft Actor-Critic Algorithms and Applications

3:Deep Reinforcement Learning Fundamentals, Research and Applications

4:From Policy Gradient to Actor-Critic methods Soft Actor Critic, ISIR

5:https://github.com/cyoon1729/Policy-Gradient-Methods

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-06 23:10:15  更:2022-04-06 23:12:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:19:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码