IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 深度强化学习(DRL)五:优先回放DQN(Prioritized experience replay) -> 正文阅读

[Python知识库]深度强化学习(DRL)五:优先回放DQN(Prioritized experience replay)

全部代码

https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN

一、优先回放

在经验回放中是利用均匀分布采样,而这种方式看上去并不高效,对于智能体而言,这些数据的重要程度并不一样,因此提出优先回放(Prioritized Replay)的方法。优先回放的基本思想就是打破均匀采样,赋予学习效率高的样本以更大的采样权重。

一个理想的标准是智能体学习的效率越高,权重越大。符合该标准的一个选择是TD偏差δ。TD偏差越大,说明该状态处的值函数与TD目标的差距越大,智能体的更新量越大,因此该处的学习效率越高。

简而言之,就是在原来的replay buffer中给每个Transition增加了抽样的优先级(priority)

优先回放DQN主要有三点改变:

1, 为了方便优先回放存储与及采样,采用sumTree树来存储;

原文有两种方法计算样本抽样概率:proportional priority和rank-based priority。proportional priority就是样本被sample到的概率是正比于TD偏差的priority;rank-based priority就是概率正比于Transition priority的排序(rank)。这里考虑proportional priority,Transition被抽到的概率与TD偏差成正比。

并且,为保证每一个存入的Transition都能被sample到,新Transition会被赋予一个很大的priority。

2, 目标函数在计算时根据样本的TD偏差添加了权重(权重和TD偏差有关,偏差越大,权重越大):
1 m ∑ j = 1 m w j ( y j ? Q ( s j , a j , w ) ) 2 \frac{1}{m}\sum\limits_{j=1}^m w_j (y_j-Q(s_j, a_j, w))^2 m1?j=1m?wj?(yj??Q(sj?,aj?,w))2

3,每次更新Q网络参数时,都需要重新计算TD误差 δ j = y j ? Q ( s j , a j , w ) \delta_j = y_j- Q(s_j, a_j, w) δj?=yj??Q(sj?,aj?,w)

二、代码

Prioritized experience replay 结合之前的 Double DQN 和 Dueling DQN

SumTree和ReplayMemory_Per

SumTree主要实现:add()添加experience;get()按priority抽样;update()更新某个Transition的priority。

ReplayMemory_Per主要实现:push()插入新experience;sample()按priority抽样Transition;update()更新已有经验的priority


class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])


class ReplayMemory_Per(object):
    # stored as ( s, a, r, s_ ) in SumTree
    def __init__(self, capacity=1000, a=0.6, e=0.01):
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1
        self.a = a
        self.e = e

    def push(self, *args):
        data = Transition(*args)
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        idxs = []
        segment = self.tree.total() / batch_size
        sample_datas = []

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = uniform(a, b)
            idx, p, data = self.tree.get(s)

            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        return self.tree.n_entries

每次更新Q网络参数时,都需要重新计算TD误差,并且更新SumTree。

关于目标函数在计算时根据样本的TD偏差添加了权重这一点并未采用



class PerDQN:
    def __init__(self, n_action, n_state, learning_rate):

        self.n_action = n_action
        self.n_state = n_state

        self.memory = ReplayMemory_Per(capacity=100)
        self.memory_counter = 0

        self.model_policy = DNN(self.n_state, self.n_action)
        self.model_target = DNN(self.n_state, self.n_action)
        self.model_target.load_state_dict(self.model_policy.state_dict())
        self.model_target.eval()

        self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)

    def store_transition(self, s, a, r, s_):
        state = torch.FloatTensor([s])
        action = torch.LongTensor([a])
        reward = torch.FloatTensor([r])
        next_state = torch.FloatTensor([s_])
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        state = torch.FloatTensor(state)
        if np.random.randn() <= EPISILO:  # greedy policy
            with torch.no_grad():
                q_value = self.model_policy(state)
                action = q_value.max(0)[1].view(1, 1).item()
        else:  # random policy
            action = torch.tensor([randrange(self.n_action)], dtype=torch.long).item()

        return action

    def learn(self):
        if self.memory.size() < BATCH_SIZE:
            return
        idxs, transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).unsqueeze(1)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        state_action_values = self.model_policy(state_batch).gather(1, action_batch)

        next_action_batch = torch.unsqueeze(self.model_policy(next_state_batch).max(1)[1], 1)
        next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)

        td_errors = (state_action_values - expected_state_action_values).detach().squeeze().tolist()
        self.memory.update(idxs, td_errors)  # update td error
        loss = F.mse_loss(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model_policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def update_target_network(self):
        self.model_target.load_state_dict(self.model_policy.state_dict())

参考

  1. https://zhuanlan.zhihu.com/p/128176891
  2. https://www.cnblogs.com/jiangxinyang/p/10112381.html
  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-10 22:27:16  更:2022-03-10 22:27:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 21:15:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计