IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 强化学习(四):Prioritized Replay DQN、Dueling DQN,附源码解读 -> 正文阅读

[人工智能]强化学习(四):Prioritized Replay DQN、Dueling DQN,附源码解读

强化学习(四):Prioritized Replay DQN、Dueling DQN,附源码解读

本次将带来另外两个DQN算法的变种,Prioritized Replay DQN和Dueling DQN;

1 Prioritized Replay DQN

之前的DQN算法中在经验回访中利用的是均匀分布采样,而这种方式看上去并不高效,对于智能体而言,这些数据的重要程度并不一样,因此提出优先回放(Prioritized Replay)的方法。优先回放的基本思想就是打破均匀采样,赋予学习效率高的样本以更大的采样权重。

一个理想的标准是智能体学习的效率越高,权重越大。符合该标准的一个选择是TD偏差δ。TD偏差越大,说明该状态处的值函数与TD目标的差距越大,智能体的更新量越大,因此该处的学习效率越高。

优先回放DQN主要有三点改变:

(1)为了方便优先回放存储与及采样,采用sumTree树来存储;
(2)目标函数在计算时根据样本的TD偏差添加了权重(权重和TD偏差有关,偏差越大,权重越大):

在这里插入图片描述
(3)在这里TD误差的计算:

在这里插入图片描述
因此每次更新Q网络参数时,都需要重新计算TD误差。

算法流程:

在这里插入图片描述
sumTree

sumTree的结构如图示,其中序号表示节点的索引值index,从0开始。要存储的优先级放在树的叶结点上,也就是最下面一层,其中叶节点同时还存储<s、a、r、s_>等信息。叶节点上面的父节点存储的是左右子节点的优先级值之和。

在这里插入图片描述
叶节点的个数等于其前面所有曾的节点总数减1,即设叶节点个数为capacity,则整棵树的节点总数为2×capacity-1;

在sumTree中存在一个data结构,其存储的就是<s、a、r、s_>的信息。

在这里插入图片描述
存priority值的时候是从叶节点开始的,索引从capacity-1开始;比如上面的图中capacity=8,则capacity-1=7为第一个叶节点的索引值。

设置一个write来对应p值对应的data信息存储,每存一个信息,write就加1,所以p值的data信息索引值即为write,并且p值的索引值与叶子节点的索引值之差为capacity-1。

父节点的索引值:每两个子节点对应一个父节点,所以父节点的索引值=(左边子节点的索引值-1)/ 2;

每添加一个p值,整棵树都需要更新一下与这个p值有关的所有父节点;

定义一个根据数字s来描述采样节点的算法:从根节点开始比较,即index=0,如果左边的子节点的p值比s大,则走左边子节点这条,如果左边子节点小于s,则走右子节点,但s值要减去左子节点的p值,按照这个规则,一直找到叶结点,返回其索引,以及对应的q值,还有对应的data。

2 Dueling DQN

Dueling DQN尝试通过优化神经网络的结构来优化算法,将Q网络分成两部分,第一部分与状态s有关,而与具体要采用的动作无关,这部分叫做价值函数部分,第二部分同时与状态s和动作a有关,这部分叫做优势函数部分,最终的价值函数表示:

价值函数:

在这里插入图片描述
优势函数:

在这里插入图片描述
其中,w是公共部分的网络参数,α、β分别是价值函数和优势函数独有的网络参数。

最终的价值函数表示:

在这里插入图片描述
在Dueling DQN中,我们在后面加了两个子网络结构,分别对应上面上到价格函数网络部分和优势函数网络部分。对应上面右图所示。最终Q网络的输出由价格函数网络的输出和优势函数网络的输出线性组合得到。

在这里插入图片描述
并且为了最终能够辨识出输出里面的价值函数和优势函数各自的作用,实际使用的价值函数公式如下:

在这里插入图片描述
对优势函数部分做了去中心化的处理。其余部分与DQN算法无差异。

源码解读

其中Dueling DQN代码跟之前的DQN代码差不多,在此不做分析;

这里主要讲一下Prioritized Replay DQN代码中的sumTree部分,莫烦大佬这里写的不是特别好理解,我结合了网上的概念以及github上的源码进行理解;

首先是sumTree中,其中的data_pointer就是前面理论部分里的write,指的就是叶节点的位置;

__init__函数定义了tree和data,data是叶节点的,用于存储数据的,tree是整个树的节点总数,代码中叶节点有10000个,所以tree的数据是19999个;

add函数:

tree_idx计算的是叶节点的”坐标“;在前面理论部分可以看到,叶节点的"坐标"是与write差capacity-1的,而write就是代码中的data_pointer;

        self.data[self.data_pointer] = data  # update data_frame,更新数据
        self.update(tree_idx, p)  # update tree_frame,更新叶子节点和父节点的优先级

update函数:

        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p  #叶子节点的优先级更新
        # then propagate the change through tree
        while tree_idx != 0:    # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change   #父节点的优先级更新

这里的change是更新了叶节点的优先级p值后,其父节点的p值也要更新,所以就是父节点原本的p值+change;

tree_idx = (tree_idx - 1) // 2 计算的就是父节点,前面理论部分有介绍;

get_leaf函数:

该函数返回的是数据data、叶节点的索引值index以及对应的优先级p值;

这里是一个从顶点向叶节点的搜索过程,可以结合前面的理论部分理解这段代码

        parent_idx = 0
        while True:     # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1         # this leaf's left and right kids,cl_idx是左边的子节点,cr_idx是右边的子节点,是index
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):        # reach bottom, end search,如果搜到到了叶子节点则结束搜索
                leaf_idx = parent_idx
                break
            else:       # downward search, always search for a higher priority node,向下搜索总是搜索优先级高的节点,就是跟左边的优先级比较,如果小于左边子节点的优先级,则走左边,否则走右边;
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1   #获得叶子节点的index

total_p函数具体作用还是很迷,它的值好像一直是10000?就是说是叶节点的总数。

然后是Memory,这一部分比sumTree的理解要难一些;

刚开始定义几个值,然后__init__函数就是建立树sumTree;

store函数:

        max_p = np.max(self.tree.tree[-self.tree.capacity:])   #从叶节点中找
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)   # set the max p for new p,transition是数据

讲一下self.tree.tree[-self.tree.capacity:],为什么是-self.tree.capacity,因为这一行代码找的是叶节点,但是tree中不仅包含了叶节点,还有其父节点,比如有19999个节点,叶节点的个数是10000个,那么按照存储规则来看的话叶节点是最后10000个数据,所以是-self.tree.capacity,意思就是从最后一个数据开始找,从最后往前开始数10000个数据,就都是叶节点的数据了;

sample函数:

        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))   #产生的数据是随机的
        pri_seg = self.tree.total_p / n       # priority segment,优先级分割,因为只需要32个数据,均分32组,每组选一个数据
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p     # for later calculate ISweight,便于后续计算ISweight
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)    #应该是每组的优先级的意思
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights   #b_idx是叶子节点的index,b_memory是b_idx对应的叶子节点中的数据,ISWeights就是权重

第二行代码pri_seg = self.tree.total_p / n,这里就涉及到了sumTree取数据了。比如我要取两个数据,整体的优先级p值是100,那么我就要从[0,50)中取一个数据,再从[50,100)中取一个数据,就是说先均匀分割整个树,分割的数量按照你需要的数据数量来,然后再从每一个分割后的树中按p值选取一个数据;这就是这一步的作用,就是分割的作用;

后面的代码就没细看的,有的跟计算权重ISWeight有关;

batch_sample函数的作用是选取完数据后更新tree中的优先级p值,当然除了叶节点之外其父节点也要对应更新。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-04 11:12:44  更:2021-08-04 11:13:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 15:03:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码