已经学好全部的表格方法训练Agent,所以,自己拿一个游戏环境进行智能体的训练。
一、选取游戏
打开 gym官网文档 ,里面就是对CartPole-v0 的介绍,感觉这个游戏也挺有意思。但是这个游戏环境的状态是连续的。不过从游戏的画面看,参数空间应该不是很大,所以决定探索整个状态空间之后再决定是否更换游戏。
import gym
import numpy as np
import matplotlib.pyplot as plt
env = gym.make('CartPole-v0')
s = env.reset()
print('High: ', env.observation_space.high,
'\nLow: ',env.observation_space.low)
"""
High: [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]
Low: [-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]
"""
可以看出第一个状态元素的范围是 [-4.8, 4.8] ,第三个状态元素的范围是 [-0.4, 0.4] ,第二个和第四个元素没有上限需要看一下元素的分布。
env = gym.make('CartPole-v0')
s = env.reset()
print('High: ', env.observation_space.high,
'\nLow: ',env.observation_space.low)
np.round(-4.1887903e-01, 1)
ob_list = []
for _ in range(10000):
env.render()
a = env.action_space.sample()
n_state, reward, done, info = env.step(a)
ob_list.append(np.take(n_state, [1, 3]))
env.close()
ob_arr = np.stack(ob_list)
fig, axes = plt.subplots(2, 1, figsize=(16, 4))
for i in range(2):
axes[i].plot(ob_arr[:, i])
max_ = ob_arr[:, i].max()
min_ = ob_arr[:, i].min()
x = 1 if i == 0 else 3
axes[i].set_title(f'observation-[{x}] max: ${max_:.2f}$ min: ${min_:.2f}$')
axes[i].set_xticks([])
plt.show()
随机运动的分布有点尴尬。我们可以换一个思路, 如果我们训练有效的话,剩余的状态元素一定会停在一定的区间,所以我们现在可以假定状态元素1和状态元素3都在[-3, 3] 之间
二、构建智能体
因为我们的状态空间非常大, 如果用Actor-Critic 方法训练需要存两个"表格",构建初始状态表会非常耗时,并且环境的结束反馈和一般表格游戏不一样,该游戏都是以reward为0结束,所以用Actor-Critic 以及SARSA会学习不到太多信息。于是我们选择monta carlo方法进行迭代学习。 __make_Q_table 就是按刚刚的思路构建的状态空间会非常巨大,以精度0.1计算,就有3百70万多的状态。
s1_num = (4.8 + 4.8) * 10 + 2
s2_num = 6 * 10 + 2
s3_num = 0.8 * 10 + 2
s4_num = 6 * 10 + 2
"""
>>> s1_num*s2_num*s3_num*s4_num
3767120.0
"""
显然,这么庞大的状态空间,我们是无法进行很好的训练的(做好心理准备),但是一定可以训练出在一定的移动空间内,木板的平衡。
class CartPoleActor:
def __init__(self, env, epsilon=-1, round_num=2):
self.epsilon = epsilon
self.round_num = round_num
self.actions = list(range(env.action_space.n))
self.a_low, self.b_low = np.round(np.take(env.observation_space.low, [0, 2]), self.round_num)
self.a_high, self.b_high = np.round(np.take(env.observation_space.high, [0, 2]), self.round_num)
self.Q = self.__make_Q_table()
def get_distrubution_arr(self, a_low, a_high, round_num):
a_cnt = int((a_high - a_low) * (10 ** round_num) + 1)
a = np.round(np.linspace(a_low, a_high, a_cnt), round_num)
if not np.sum(0 == a):
a = np.concatenate([a, np.array([-0., 0.])])
else:
a = np.concatenate([a, np.array([-0.])])
return a
def __make_Q_table(self):
a = self.get_distrubution_arr(self.a_low, self.a_high, self.round_num)
b = self.get_distrubution_arr(-3., 3., self.round_num)
c = self.get_distrubution_arr(self.b_low, self.b_high, self.round_num)
d = self.get_distrubution_arr(-3., 3., self.round_num)
Q_dict = dict()
for s1 in a:
for s2 in b:
for s3 in c:
for s4 in d:
Q_dict[str(np.round(np.array([s1, s2, s3, s4]), self.round_num))
] = np.random.uniform(0, 1, len(self.actions))
print('len(Q_dict) = ', len(Q_dict))
return Q_dict
@staticmethod
def softmax(x):
return np.exp(x) / np.sum(np.exp(x), axis=0)
def policy(self, s):
return np.random.choice(self.actions, size=1,
p=self.softmax(self.Q[s]))[0]
二、训练智能体
class CartPoleMontoCarlo:
def __init__(self, actor_cls, round_num):
self.actor_cls = actor_cls
self.round_num = round_num
def take_state(self, state):
if type(state) == str:
return state
s1, s2, s3, s4 = np.round(state, self.round_num)
s1 = np.clip(s1, -4.8, 4.8)
s2 = np.clip(s3, -3., 3.)
s3 = np.clip(s2, -0.4, 0.4)
s4 = np.clip(s4, -3., 3.)
return str(np.round(np.array([s1, s2, s3, s4]), self.round_num))
def train(self, env, gamma=0.9, learning_rate=0.1, epoches=1000, render=False):
actor = self.actor_cls(env, round_num = self.round_num)
loop_cnt_list = []
for e in range(epoches):
s = env.reset()
done = False
loop_cnt = 0
state_list = []
action_list = []
reward_list = []
while not done:
if render and (e % 50 == 0):
env.render()
s = self.take_state(s)
a = actor.policy(s)
n_state, reward, done, info = env.step(a)
n_state = self.take_state(n_state)
state_list.append(s)
action_list.append(a)
reward_list.append(reward)
s = n_state
loop_cnt += 1
else:
loop_cnt_list.append(loop_cnt)
game_len = len(state_list)
for i in range(game_len):
s, a = state_list[i], action_list[i]
G, t = 0, 0
for j in range(i, game_len):
G += np.power(gamma, t) * reward_list[j]
t += 1
actor.Q[s][a] += learning_rate * (G - actor.Q[s][a])
if render and (e % 50 == 0):
env.close()
if e % 50 == 0:
m_ = np.mean(loop_cnt_list[-50:])
std_ = np.std(loop_cnt_list[-50:])
reward_m = np.mean(reward_list[-50:])
print(f'Epoch [{e}]: balance last {m_:.2f} (+/- {std_:.3f}) times - rewards {reward_m:.3f}')
return actor, loop_cnt_list, reward_list
trainer = CartPoleActorCritic(CartPoleActor, round_num=1)
env = gym.make('CartPole-v0')
actor, loop_cnt_list, reward_list = trainer.train(
env,
epoches=15000,
gamma=0.9,
learning_rate=0.1,
render=False)
smooth_cnt = []
for idx in range(len(loop_cnt_list)-1000):
smooth_cnt.append(np.mean(loop_cnt_list[idx:idx+1000]))
plt.plot(smooth_cnt)
plt.show()
从图中我们可以看出,我们训练的智能体能保持130回合左右的稳定。
观测训练出的智能体的表现
因为训练的智能体的状态空间有限,reset的观测的时候会时好时差,因为我们的平均良好表现在130回合,而现在我们观察240回合,不过多次尝试中也可以看到,有时候也是表现的不错的。
env = gym.make('CartPole-v0')
s = env.reset()
for _ in range(240):
env.render()
s_for_a = trainer.take_state(s)
a = actor.policy(s_for_a)
n_state, reward, done, info = env.step(a)
print(_, s_for_a, actor.Q[s_for_a], reward, done)
s = n_state
env.close()
四、深度强化学习的必要性
从上述的实践中我们可以看出,表格方法在这种连续的任务中难度非常大。所以需要深度强化学习的出场。等学好深度强化学习后再重新训练当前的游戏环境再做比较。
|