IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 游戏开发 -> reinforcement_q_learning—解读 -> 正文阅读

[游戏开发]reinforcement_q_learning—解读

注意:
Torch官网reinforcement_q_learning源码
下文中不提供源码,需一边阅读一边对照源码,仅针对部分必要源码变量进行分析


原理讲解

如果不了解大致符号和理论定义,参考强化学习入门—超级马里奥

  1. 简单来说,定义一个模型,输入是整个画面,输出是每个动作带来的未来价值(维度与动作数相同)
  2. 计算两个神经网络对象,假定一个是目标网络target,一个是需学习的网络policy
  3. policy看到当前画面,作出相应动作后的总价值 U t U_t Ut?
  4. target看到下一个时刻 t + 1 t+1 t+1时的最大价值 U t + 1 U_{t+1} Ut+1?
  5. t t t t + 1 t+1 t+1相差的是时刻 t t t的奖励 R t R_t Rt?
  6. U t = γ ? U t + 1 + R t U_t=\gamma*U_{t+1}+R_t Ut?=γ?Ut+1?+Rt?
  7. 理论上的当前最大价值与当前实际价值做误差

gym

import random
import time
from itertools import count

import gym
import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 环境创建
env = gym.make('CartPole-v0').unwrapped

env.reset()  # 以下的每个reset后面的内容都可以替换该部分后面
env.render()
time.sleep(5)
# 执行结果:显示出环境并停留5s

# 2. 获取当前画面
env.reset()
screen = env.render(mode='rgb_array').transpose((2, 0, 1))
_, screen_height, screen_width = screen.shape
screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
for i in screen:
    for j in i:
        for k in j:
            print(f"{k:3}", end=" ")
        print()
# 执行结果:打印出来一堆数字,翻找中间有为0的一行,再向右翻,能看到不同数字,代表颜色,前面环境的画面颜色相对应
# 注意:不必过分纠结gym的具体使用,总之,通过上述测试的screen,可以得到用于图像识别模型的输入

# 3. 获取action的取值范围
env.reset()
n_actions = env.action_space.n
print(n_actions)
# 执行结果:2(action的取值范围为2)

# 4. 随机移动
for i in tqdm(range(50)):
    env.reset()
    env.render()
    for _ in count():
        _, _, done, _ = env.step(torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long).item())
        env.render()
        if done:
            break
# 注意:render表示提交或者绘制,一开始先使用一次render,之后每次执行step后必须重新render来更新

# 最终必须close
env.close()

gym.make的参数位于conda相应环境的site-packages/gym/envs/__init__.py中,对应gym官网的环境

ReplayMemory

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
# Transition并没有特殊格式,定义成一个包含相应成员变量的类也可以

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

DQN

class DQN(nn.Module):
	def __init__(self, h, w, outputs):
		...

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
# 神经网络模型,输入:一张图,输出:维度与action的取值范围相对应
target_net.load_state_dict(policy_net.state_dict())
# target_net初始与policy_net所有参数保持一致

select_action(state)

# 选择动作:设定负指数幂函数,取值范围从1减小到0,使得,大多数情况执行2,随着时间推移更容易执行1
# 1. 根据当前画面,预测最佳动作
# 2. 随机在action取值范围选择一个
def select_action(state):
	...
# 返回值:torch.tensor([[action取值范围内的一个数]], device=device, dtype=torch.long)

optimize_model

def optimize_model():
    ...
    batch = Transition(*zip(*transitions))
    # batch = Transition(state=(state_1, state_2, ...), action=(action_1, action_2, ...), ...)
	
	# 提取出来不是None的next_state
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    ...

    # 当前状态,执行当前动作,带来的从当前时刻t到未来的所有价值:U_t=state_action_values(选择不同action的价值)
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    # 时刻t+1到未来的所有价值为U_t+1=next_state_values(用target_net预测)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    
    # 理论上,U_t=gamma*U_t+1+R_t
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # 理论U_t与实际U_t计算loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
	...
  游戏开发 最新文章
6、英飞凌-AURIX-TC3XX: PWM实验之使用 GT
泛型自动装箱
CubeMax添加Rtthread操作系统 组件STM32F10
python多线程编程:如何优雅地关闭线程
数据类型隐式转换导致的阻塞
WebAPi实现多文件上传,并附带参数
from origin ‘null‘ has been blocked by
UE4 蓝图调用C++函数(附带项目工程)
Unity学习笔记(一)结构体的简单理解与应用
【Memory As a Programming Concept in C a
上一篇文章      下一篇文章      查看所有文章
加:2022-04-27 11:37:15  更:2022-04-27 11:39:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/17 1:01:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码