reinforcement_q_learning: A Walkthrough
Note: this post follows the official PyTorch reinforcement_q_learning tutorial. The full source code is not reproduced below; read this alongside the tutorial source. Only the variables that need explanation are analyzed here.
How it works
If the notation and basic theoretical definitions are unfamiliar, see the earlier post 强化学习入门—超级马里奥 (Introduction to Reinforcement Learning: Super Mario).
- Simply put: define a model whose input is the whole game screen and whose output is the estimated future value of every action (the output dimension equals the number of actions).
- Two network objects are used: a target network `target` and the network to be learned, `policy`.
- `policy` looks at the current screen and predicts the total value $U_t$ obtained after taking the chosen action.
- `target` looks at the next time step $t+1$ and predicts the maximum value $U_{t+1}$ at that step.
- What separates time $t$ from time $t+1$ is the reward $R_t$ received at time $t$.
- $U_t = \gamma \cdot U_{t+1} + R_t$
- The loss is the error between this theoretical current maximum value and the value the policy network actually predicts (see the sketch right after this list).
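A minimal sketch of that error term with made-up numbers (`q_t`, `max_q_t1`, `r_t` are hypothetical tensors used only for illustration; the tutorial's real version is the optimize_model section below):

```python
import torch
import torch.nn as nn

GAMMA = 0.999  # discount factor

# Hypothetical per-sample values, just to show the arithmetic:
q_t = torch.tensor([0.7, 1.2])       # policy net's value U_t of the actions taken
max_q_t1 = torch.tensor([1.0, 0.9])  # target net's max value U_{t+1} of the next states
r_t = torch.tensor([1.0, 1.0])       # rewards R_t

td_target = GAMMA * max_q_t1 + r_t             # "theoretical" value: gamma * U_{t+1} + R_t
loss = nn.SmoothL1Loss()(q_t, td_target)       # error between predicted and theoretical value
print(loss)
```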
gym
import random
import time
from itertools import count

import gym
import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('CartPole-v0').unwrapped

# Show the initial frame for a few seconds.
env.reset()
env.render()
time.sleep(5)

# Grab one frame as an RGB array (HWC -> CHW) and crop the vertical band that contains the cart.
env.reset()
screen = env.render(mode='rgb_array').transpose((2, 0, 1))
_, screen_height, screen_width = screen.shape
screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
for i in screen:              # channel
    for j in i:               # row
        for k in j:           # pixel value
            print(f"{k:3}", end=" ")
        print()

# Number of discrete actions (2 for CartPole: push left / push right).
env.reset()
n_actions = env.action_space.n
print(n_actions)

# Run 50 episodes with random actions to watch the environment move.
for i in tqdm(range(50)):
    env.reset()
    env.render()
    for _ in count():
        _, _, done, _ = env.step(torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long).item())
        env.render()
        if done:
            break
env.close()
The strings accepted by gym.make are registered in site-packages/gym/envs/__init__.py inside the corresponding conda environment; the entries there match the environments listed on the Gym website.
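Instead of opening that file by hand, the registry can also be inspected from Python. This assumes the classic gym API (the 0.21-era interface that `gym.make('CartPole-v0')` above relies on); newer gymnasium releases expose the registry differently:

```python
# List every registered environment id.
from gym import envs

all_ids = [spec.id for spec in envs.registry.all()]
print(len(all_ids))
print([i for i in all_ids if "CartPole" in i])  # e.g. CartPole-v0, CartPole-v1
```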
ReplayMemory
from collections import namedtuple, deque

# One transition: the agent was in `state`, took `action`, landed in `next_state`, got `reward`.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    def __init__(self, capacity):
        # A bounded FIFO buffer: once full, the oldest transitions are dropped.
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniformly sample a batch of transitions for training.
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
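A quick usage sketch of ReplayMemory with dummy tensors (the shapes and values are made up; only the interface matters):

```python
import torch

memory = ReplayMemory(capacity=10000)

for _ in range(8):
    state = torch.zeros(1, 3, 40, 90)       # a dummy screen tensor
    action = torch.tensor([[1]])             # index of the action taken
    next_state = torch.zeros(1, 3, 40, 90)   # the tutorial stores None here for terminal states
    reward = torch.tensor([1.0])
    memory.push(state, action, next_state, reward)

batch = memory.sample(4)   # a list of 4 Transition namedtuples, drawn at random
print(len(memory), len(batch))
```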
DQN
class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        ...

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
# The target network starts as an exact copy of the policy network's weights.
target_net.load_state_dict(policy_net.state_dict())
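The network body is elided above. For orientation, here is a sketch in the spirit of the official tutorial's architecture: a few convolution + batch-norm layers and a linear head whose input size is derived from the screen height and width. The channel counts and kernel sizes below are assumptions; check the tutorial source for the real definition.

```python
import torch.nn as nn
import torch.nn.functional as F

class DQNSketch(nn.Module):
    def __init__(self, h, w, outputs):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # Spatial size of one dimension after a conv with the settings above.
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        self.head = nn.Linear(convw * convh * 32, outputs)  # one value per action

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.flatten(start_dim=1))
```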
select_action(state)
def select_action(state):
    ...
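select_action implements an epsilon-greedy policy: with probability eps act randomly (explore), otherwise take the action policy_net values most (exploit), with eps decaying as training proceeds. A sketch under that reading, reusing policy_net, n_actions and device from above; EPS_START, EPS_END, EPS_DECAY and steps_done are assumed names and values, not quotes from the tutorial:

```python
import math
import random
import torch

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200
steps_done = 0

def select_action_sketch(state):
    global steps_done
    # Exponentially decay eps from EPS_START toward EPS_END.
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    steps_done += 1
    if random.random() > eps:
        with torch.no_grad():
            # policy_net(state) has shape (1, n_actions); pick the argmax action.
            return policy_net(state).max(1)[1].view(1, 1)
    # Otherwise explore with a uniformly random action.
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
```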
optimize_model
def optimize_model():
    ...
    # Turn a list of Transitions into one Transition whose fields are batched tuples.
    batch = Transition(*zip(*transitions))
    # Mask of transitions whose episode did not end (next_state is not None).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    ...
    # U_t: the policy net's value of the action actually taken in each state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # U_{t+1}: the target net's maximum value for non-final next states, 0 for final ones.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # The "theoretical" value gamma * U_{t+1} + R_t from the formula above.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Huber loss between the predicted U_t and the theoretical value.
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    ...
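Two of these lines are easy to misread: the gather call and the masked assignment. A tiny standalone example with made-up numbers shows what each one computes:

```python
import torch

# .gather(1, action_batch): pick, for each row, the Q-value of the action actually taken.
q = torch.tensor([[0.2, 0.8],
                  [0.5, 0.1]])            # pretend policy_net output for a batch of 2 states
actions = torch.tensor([[1], [0]])        # actions taken: 1 for the first state, 0 for the second
print(q.gather(1, actions))               # tensor([[0.8], [0.5]])

# The non_final_mask assignment: leave U_{t+1} at 0 for terminal states.
next_values = torch.zeros(2)
mask = torch.tensor([True, False])        # the second transition ended the episode
next_values[mask] = torch.tensor([0.9])   # target_net's max value for the surviving state
print(next_values)                        # tensor([0.9000, 0.0000])
```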