IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> D3QN代码实现 -> 正文阅读

[人工智能]D3QN代码实现

D3QN代码实现

  • 使用tensorflow实现使用D3QN。

代码及解释

1.包引入与参数设定

import argparse
import os
import random

import numpy as np

import gym
import tensorflow as tf
import tensorlayer as tl

from matplotlib import animation
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument('--train', dest='train', default=False)
# 训练时是否渲染
parser.add_argument('--render', type=bool, default=False)
parser.add_argument('--save_gif', type=bool, default=True)

parser.add_argument('--gamma', type=float, default=0.995)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--eps', type=float, default=0.2)

parser.add_argument('--train_episodes', type=int, default=1000)
parser.add_argument('--test_episodes', type=int, default=10)
args = parser.parse_args()

ALG_NAME = 'D3QN'
ENV_ID = 'LunarLander-v2'

2.ReplayBuffer的实现

import random  
import numpy as np

class ReplayBuffer:  
	def __init__(self, capacity=50000):  
		self.capacity = capacity  
		self.buffer = []
		#buffer满了之后要从头开始循环利用
		self.position = 0  
	
	def push(self, state, action, reward, next_state, done):  
		if len(self.buffer) < self.capacity:  
			self.buffer.append(None)  
		self.buffer[self.position] = (state, action, reward, next_state, done)  
		self.position = int((self.position + 1) % self.capacity)  

	def sample(self, batch_size = args.batch_size): 
		#从buffer里随机抽batch_size个transition出来
		batch = random.sample(self.buffer, batch_size)
		#把这batch_size个transition分门别类放在几个数组里
		state, action, reward, next_state, done = map(np.stack, zip(*batch))  
		return state, action, reward, next_state, done

3.D3QN类的实现

  • D3QN类主要实现8个方法。
    • _init_:初始化agent。
    • target_update:用于更新target network。
    • choose_action:选择动作。
    • replay:使用梯度下降更新价值函数。
    • test_episode:用于测试模型。
    • train:用于采集训练模型所需要的参数。
    • saveModel:保存模型。
    • loadModel:加载模型。
3.1. _init_
  • D3QN网络建立
def create_model(input_state_shape):
    input_layer = tl.layers.Input(input_state_shape)
    layer_1 = tl.layers.Dense(n_units=256, act=tf.nn.relu)(input_layer)
    layer_2 = tl.layers.Dense(n_units=128, act=tf.nn.relu)(layer_1)

    state_hidden = tl.layers.Dense(n_units=64)(layer_2)
    adv_hidden = tl.layers.Dense(n_units=64)(layer_2)

    # state value
    state_value = tl.layers.Dense(n_units=1)(state_hidden)
    # advantage value
    adv_value = tl.layers.Dense(n_units=self.action_dim)(adv_hidden)

    mean = tl.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True))(adv_value)
    advantage = tl.layers.ElementwiseLambda(lambda x, y: x-y)([adv_value, mean])
    # output
    output_layer = tl.layers.ElementwiseLambda(lambda x, y: x+y)([state_value, advantage])
    return tl.models.Model(inputs=input_layer, outputs=output_layer)
  • _init_
def __init__(self, env):
    self.env = env
    self.state_dim = self.env.observation_space.shape[0]
    self.action_dim = self.env.action_space.n

    self.model = create_model([None, self.state_dim])
    self.target_model = create_model([None, self.state_dim])
    self.model.train()
    self.target_model.eval()
    self.model_optim  = tf.optimizers.Adam(lr=args.lr)

    self.epsilon = args.eps

    self.buffer = ReplayBuffer()
3.2. target_update
def target_update(self):  
	"""Copy q network to target q network"""  
	for weights, target_weights in zip(  
			self.model.trainable_weights, self.target_model.trainable_weights):  
		target_weights.assign(weights)
3.3. choose_action
def choose_action(self, state):
    if np.random.uniform() < self.epsilon:
        return np.random.choice(self.action_dim)
    else:
        q_value = self.model(state[np.newaxis, :])[0]
        return np.argmax(q_value)
  • np.random.uniform(low=0,high=1.0),生成随机数,默认范围是[0,1]
  • choose_action函数首先产生一个范围为[0,1]的随机数,如果随机数小于ε,则进行探索,否则使用价值函数对当前状态进行评估,选择q值最大的动作。
  • [np.newaxis, :]的作用是在np.newaxis的位置添加新的维度,在这里state是形状为(,state.dim)的向量,添加维度0后,就变成了(1,state.dim)维的向量。
  • model后面加[0]是因为此时只输入了一个state,因此结果也只返回一组动作的q_value值。
  • np.argmax的作用是找到数组中最大的数,并返回下标。
3.4. replay
  • 在replay函数中,主要完成价值网络参数的更新,也是本代码中主要使用"Cuda"计算的地方。
def replay(self):
    for _ in range(10):
        states, actions, rewards, next_states, done = self.buffer.sample()
        target = self.target_model(states).numpy()
        # next_q_values [batch_size, action_dim]
        next_target = self.target_model(next_states).numpy()
        # next_q_value [batch_size, 1]
        next_q_value = next_target[
            range(args.batch_size), np.argmax(self.model(next_states), axis=1)
        ]
        target[range(args.batch_size), actions] = rewards + (1 - done) * args.gamma * next_q_value

        # use sgd to update the network weight
        with tf.GradientTape() as tape:
            q_pred = self.model(states)
            loss = tf.losses.mean_squared_error(target, q_pred)
        grads = tape.gradient(loss, self.model.trainable_weights)
        self.model_optim.apply_gradients(zip(grads, self.model.trainable_weights))
  • D3QN使用Q网络选择动作,再用Target网络评估价值。
3.5. test_episode
  • 在test_episode函数中,对模型进行测试数次,并将每次运行的结果保存为gif文件。
def test_episode(self, test_episodes):
    for episode in range(test_episodes):
        state = self.env.reset().astype(np.float32)
        total_reward, done = 0, False
        frames = []
        while not done:
            action = self.model(np.array([state], dtype=np.float32))[0]
            action = np.argmax(action)
            next_state, reward, done, _ = self.env.step(action)
            next_state = next_state.astype(np.float32)

            total_reward += reward
            state = next_state
            frames.append(env.render(mode='rgb_array'))
        # 将本场游戏保存为gif
        if args.save_gif:
            dir_path = os.path.join('testVideo', '_'.join([ALG_NAME, ENV_ID]))
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            display_frames_as_gif(frames, dir_path + '\\' + str(episode) + ".gif")
        print("Test {} | episode rewards is {}".format(episode, total_reward))
  • 如何将gym运行过程保存为gif文件?
from matplotlib import animation  
import matplotlib.pyplot as plt

#第一步:定义帧画面转化为gif的函数
def display_frames_as_gif(frames, path):  
	patch = plt.imshow(frames[0])  
	plt.axis('off')  

	def animate(i):  
		patch.set_data(frames[i])  

	anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)  
	anim.save(path, writer='pillow', fps=30)
	
#第二步:定义一个frames,用于收集游戏过程中的画面
frames = []  

#第三步:在游戏运行过程中,收集画面
frames.append(self.env.render(mode = 'rgb_array'))  

#第四部:游戏运行完毕后,将frames中的内容保存为gif
dir_path = os.path.join('testVideo', '_'.join([ALG_NAME, ENV_ID]))  
if not os.path.exists(dir_path):  
	os.makedirs(dir_path)  
display_frames_as_gif(frames, dir_path + '\\' + str(episode) + ".gif")
3.6. train
def train(self, train_episodes=200):
    self.loadModel()
    if args.train:
        all_ep_r = []
        for episode in range(train_episodes):
            total_reward, done = 0, False
            state = self.env.reset().astype(np.float32)
            while not done:
                if args.render:
                    env.render()
                action = self.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state.astype(np.float32)
                reward -= 0.1
                self.buffer.push(state, action, reward, next_state, done)
                total_reward += reward
                state = next_state
                # self.render()
            if len(self.buffer.buffer) > args.batch_size:
                self.replay()
                self.target_update()

            if episode == 0:
                all_ep_r.append(total_reward)
            else:
                all_ep_r.append(all_ep_r[-1] * 0.9 + total_reward * 0.1)
            print(
                'Episode: {}/{}  | Episode Reward: {:.4f}'.format(
                    episode, args.train_episodes, total_reward
                )
            )
            # 一百轮保存一遍模型
            if episode % 100 == 0:
                self.saveModel()
	else:
        self.test_episode(test_episodes=args.test_episodes)
3.7. saveModel
def saveModel(self):
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    if not os.path.exists(path):
        os.makedirs(path)
    tl.files.save_weights_to_hdf5(os.path.join(path, 'model.hdf5'), self.model)
    tl.files.save_weights_to_hdf5(os.path.join(path, 'target_model.hdf5'), self.target_model)
    print('Saved weights.')
3.8. loadModel
def loadModel(self):
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    if os.path.exists(path):
        print('Load DQN Network parametets ...')
        tl.files.load_hdf5_to_weights_in_order(os.path.join(path, 'model.hdf5'), self.model)
        tl.files.load_hdf5_to_weights_in_order(os.path.join(path, 'target_model.hdf5'), self.target_model)
        print('Load weights!')
    else: print("No model file find, please train model first...")

4.主程序

if __name__ == '__main__':  
	env = gym.make(ENV_ID)  
	agent = D3QN(env)  
	agent.train(train_episodes=args.train_episodes)  
	env.close()

训练结果

训练1000盘后

请添加图片描述
请添加图片描述

请添加图片描述

更详细的代码解释参考:DQN with Target代码实现

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-16 19:38:58  更:2021-10-16 19:40:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 10:53:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码