Learning stable-baselines3: TensorBoard
1.Basic Usage
To use TensorBoard with stable-baselines3, you simply pass the location of a log folder to the RL agent:
```python
from stable_baselines3 import A2C

model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
```
You can also define a custom log name when calling `learn()` (by default, the algorithm name is used):
```python
from stable_baselines3 import A2C

model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# reset_num_timesteps=False keeps counting timesteps from the previous learn() call
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
```
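Each `learn()` call writes its events to a sub-folder of `tensorboard_log`. The sub-folder is named after `tb_log_name` plus a numeric suffix, for example `A2C_1`, then `A2C_2` on the next fresh run, so separate runs show up as separate curves in TensorBoard. The function below is a simplified, hypothetical sketch of that naming scheme for fresh runs only. `next_run_dir` is not SB3's own code, and it ignores how `reset_num_timesteps=False` affects the suffix.

```python
import glob
import os
import tempfile


def next_run_dir(log_root: str, tb_log_name: str) -> str:
    """Hypothetical helper: return the next '<tb_log_name>_<N>' folder under log_root.

    Simplified illustration of the naming scheme for a fresh run only
    (reset_num_timesteps=True); this is not stable-baselines3's own code.
    """
    max_id = 0
    # Match existing run folders such as 'first_run_1', 'first_run_2', ...
    pattern = os.path.join(glob.escape(log_root), glob.escape(tb_log_name) + "_[0-9]*")
    for path in glob.glob(pattern):
        suffix = os.path.basename(path).rsplit("_", 1)[-1]
        if suffix.isdigit():
            max_id = max(max_id, int(suffix))
    # The new run gets the next free number
    return os.path.join(log_root, f"{tb_log_name}_{max_id + 1}")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        print(next_run_dir(root, "A2C"))  # first run -> .../A2C_1
        os.makedirs(os.path.join(root, "A2C_1"))
        print(next_run_dir(root, "A2C"))  # second run -> .../A2C_2
```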
Once `learn` has been called, you can monitor the RL agent during or after training with the following bash command:
```bash
tensorboard --logdir ./a2c_cartpole_tensorboard/
```
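Note: run this command from the project directory, because the log path is relative to the folder the training script was run in.
For example: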
![TensorBoard screenshot](https://img-blog.csdnimg.cn/7dd67fb02179472eb9a15ee1caf080db.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![TensorBoard screenshot](https://img-blog.csdnimg.cn/3865980bd518426492b55cea5fa6b4ce.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
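Because `./a2c_cartpole_tensorboard/` is a relative path, TensorBoard resolves it against the directory the command is launched from. That is why the note above asks you to run it in the project folder. If you would rather launch TensorBoard from anywhere, one option is to pass an absolute path instead. The helper below is a hypothetical, stdlib-only sketch that builds such a command string. `tensorboard_command` is not part of SB3 or TensorBoard. It uses `shlex.join`, so it assumes Python 3.8+ and POSIX-style shell quoting.

```python
import os
import shlex


def tensorboard_command(logdir: str, project_dir: str) -> str:
    """Hypothetical helper: build a `tensorboard --logdir` command with an absolute path.

    The relative logdir (e.g. './a2c_cartpole_tensorboard/') is resolved against
    project_dir, the folder the training script ran in, instead of the shell's
    current directory. shlex.join (Python 3.8+) applies POSIX shell quoting.
    """
    abs_logdir = os.path.abspath(os.path.join(project_dir, logdir))
    return shlex.join(["tensorboard", "--logdir", abs_logdir])


if __name__ == "__main__":
    # Run once from the project folder, then paste the printed command in any shell
    print(tensorboard_command("./a2c_cartpole_tensorboard/", os.getcwd()))
```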
2.Logging More Values
Using a callback, you can easily log more values to TensorBoard. Here is a simple example that logs an extra, arbitrary scalar value (a random number) at every step:
```python
import numpy as np

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="./tmp/sac0/", verbose=1)


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Log an arbitrary scalar; it appears under the "random_value" tag in TensorBoard
        value = np.random.random()
        self.logger.record('random_value', value)
        return True  # returning False would stop training


model.learn(50000, callback=TensorboardCallback())
```
```bash
tensorboard --logdir ./tmp/sac0/
```
![TensorBoard screenshot](https://img-blog.csdnimg.cn/e0ea440df86f40cea4888794d8b5bd18.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![TensorBoard screenshot](https://img-blog.csdnimg.cn/7be1a3411f1a46289f4ac3e345f63960.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
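The callback in this section calls `self.logger.record` on every step, but TensorBoard only receives values when the logger is dumped. As far as I know, SB3's logger behaves like this:

- It keeps only the last value recorded for a key between two dumps.
- It offers `record_mean` if you want an average instead.
- For SAC, it dumps every `log_interval` episodes.

So each point on the `random_value` curve is the last random number drawn before a dump. The class below is a hypothetical, stdlib-only model of that buffering. `MiniLogger` is not an SB3 class, and `written` merely stands in for the TensorBoard event file.

```python
from collections import defaultdict
from typing import Dict, List


class MiniLogger:
    """Hypothetical, simplified stand-in for a key-value logger (not the SB3 class).

    record() keeps only the latest value per key, record_mean() keeps a running
    mean, and dump() hands the buffered values to the "writer" and clears them.
    """

    def __init__(self) -> None:
        self.values: Dict[str, float] = {}
        self.counts: Dict[str, int] = defaultdict(int)
        self.written: List[Dict[str, float]] = []  # stands in for TensorBoard events

    def record(self, key: str, value: float) -> None:
        self.values[key] = value  # overwrite: the last value before dump() wins

    def record_mean(self, key: str, value: float) -> None:
        n = self.counts[key]
        old = self.values.get(key, 0.0)
        # Incremental mean over all values recorded since the last dump
        self.values[key] = old * n / (n + 1) + value / (n + 1)
        self.counts[key] = n + 1

    def dump(self, step: int) -> None:
        self.written.append({"step": step, **self.values})
        self.values.clear()
        self.counts.clear()


if __name__ == "__main__":
    logger = MiniLogger()
    for v in [1.0, 2.0, 6.0]:  # three callback steps between two dumps
        logger.record("random_value", v)
        logger.record_mean("mean_value", v)
    logger.dump(step=3)
    print(logger.written)  # [{'step': 3, 'random_value': 6.0, 'mean_value': 3.0}]
```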
3.Logging Images
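TensorBoard supports periodic logging of image data, which helps evaluate the agent at various stages during training.
Here is an example of how to render images to TensorBoard at regular intervals: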
```python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Image

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="./tmp/sac1/", verbose=1)


class ImageRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Render the training env to an RGB array laid out as height x width x channels
        image = self.training_env.render(mode="rgb_array")
        # Images cannot be written to the text-based outputs, so exclude them
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
        return True


model.learn(50000, callback=ImageRecorderCallback())
```
![Screenshot](https://img-blog.csdnimg.cn/60e286976e8f45e1b672a6f6907c3d90.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
```bash
tensorboard --logdir ./tmp/sac1/
```
![TensorBoard screenshot](https://img-blog.csdnimg.cn/5e2b4db7c6c84adf93cd7d6b32fc418a.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
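The `exclude=("stdout", "log", "json", "csv")` argument controls which output formats receive a key. As far as I understand, the text-based writers cannot serialize an image, so the image is routed to the TensorBoard output only, while plain scalars still go everywhere. The sketch below models that routing with hypothetical names (`route`, `ALL_FORMATS`). The format names are assumed to match SB3's output format strings. It is an illustration, not the library's implementation.

```python
from typing import Dict, List, Tuple

# Assumed names of the output formats; in this sketch only "tensorboard" is meant to accept images
ALL_FORMATS = ("stdout", "log", "json", "csv", "tensorboard")


def route(records: Dict[str, Tuple[object, Tuple[str, ...]]],
          formats: Tuple[str, ...] = ALL_FORMATS) -> Dict[str, List[str]]:
    """Hypothetical helper: list, per output format, the keys it would receive.

    records maps key -> (value, exclude); a format receives a key unless its
    name appears in that key's exclude tuple.
    """
    return {
        fmt: [key for key, (_, exclude) in records.items() if fmt not in exclude]
        for fmt in formats
    }


if __name__ == "__main__":
    text_formats = ("stdout", "log", "json", "csv")
    routed = route({
        "random_value": (0.5, ()),                      # plain scalar: every format
        "trajectory/image": ("<image>", text_formats),  # image: TensorBoard only
    })
    print(routed["stdout"])       # ['random_value']
    print(routed["tensorboard"])  # ['random_value', 'trajectory/image']
```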
4.Logging Figures/Plots
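TensorBoard supports periodic logging of figures/plots created with matplotlib, which helps evaluate the agent at various stages during training.
Here is an example of how to store plots in TensorBoard at regular intervals: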
```python
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="./tmp/sac2/", verbose=1)


class FigureRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Plot some values; replace with whatever you want to inspect
        figure = plt.figure()
        figure.add_subplot().plot(np.random.random(3))
        # close=True closes the figure once it has been logged
        self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
        plt.close()
        return True


model.learn(50000, callback=FigureRecorderCallback())
```
```bash
tensorboard --logdir ./tmp/sac2/
```
![TensorBoard screenshot](https://img-blog.csdnimg.cn/5c575c6ec1ac437aacb0f4410025632a.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![TensorBoard screenshot](https://img-blog.csdnimg.cn/3a898da5d33542409801cf251b22b5d6.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
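The figure callback in this section builds and logs a new matplotlib figure on every call of `_on_step`. With a single environment, that is one figure per step, or about 50,000 figures for `model.learn(50000, ...)`. A cheaper variation is the check the video example in the next section uses: wrap the figure code in `if self.n_calls % freq == 0:`. The hypothetical `EveryN` class below isolates that gate, so you can see how many figures a given frequency would produce. It is a standalone illustration, not an SB3 API.

```python
class EveryN:
    """Hypothetical helper: gate an expensive logging action to every `freq`-th call.

    Mirrors an `n_calls % freq == 0` check: the counter is incremented first,
    so the first True happens on call number `freq`.
    """

    def __init__(self, freq: int) -> None:
        if freq <= 0:
            raise ValueError("freq must be a positive integer")
        self.freq = freq
        self.n_calls = 0

    def step(self) -> bool:
        self.n_calls += 1
        return self.n_calls % self.freq == 0


if __name__ == "__main__":
    gate = EveryN(freq=5000)
    # 50,000 callback calls -> 10 figures instead of 50,000
    n_figures = sum(gate.step() for _ in range(50_000))
    print(n_figures)  # 10
```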
5.Logging Videos
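TensorBoard supports periodic logging of video data, which helps evaluate the agent at various stages during training.
Here is an example of how to render an episode and log the resulting video to TensorBoard at regular intervals:
Note: this requires the `moviepy` package (`pip install moviepy`).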
```python
from typing import Any, Dict

import gym
import torch as th

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video


class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every render_freq calls of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                screen = self._eval_env.render(mode="rgb_array")
                # gym renders HxWxC; PyTorch/TensorBoard video frames are CxHxW
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            self.logger.record(
                "trajectory/video",
                Video(th.ByteTensor([screens]), fps=40),
                exclude=("stdout", "log", "json", "csv"),
            )
        # Return True on every call; a falsy return value would stop training
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="./tmp/runs/", verbose=1)
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
```
![TensorBoard screenshot](https://img-blog.csdnimg.cn/f9e80baff9f041719fbb7e33bae54b01.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
```bash
tensorboard --logdir ./tmp/runs/
```
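The video callback passes `th.ByteTensor([screens])` to `Video`, and the shape of that tensor is built in three steps:

- Each frame from `render(mode="rgb_array")` is laid out as height × width × channels (HWC).
- `transpose(2, 0, 1)` reorders each frame to CHW.
- The extra list brackets add a leading batch axis.

The result is the (N, T, C, H, W) uint8 layout that, to my understanding, PyTorch's `SummaryWriter.add_video` expects. The stdlib-only helpers below compute those shapes without any array library. The names `permutation` and `video_shape` are hypothetical, and the frame size used in the demo is an arbitrary example.

```python
from typing import Sequence, Tuple


def permutation(src: str, dst: str) -> Tuple[int, ...]:
    """Axis order that turns an array laid out as `src` (e.g. "HWC") into `dst` (e.g. "CHW")."""
    if sorted(src) != sorted(dst):
        raise ValueError(f"{src!r} and {dst!r} must name the same axes")
    return tuple(src.index(axis) for axis in dst)


def video_shape(frame_hwc: Sequence[int], n_frames: int) -> Tuple[int, ...]:
    """Shape of th.ByteTensor([screens]) when each HWC frame was transposed to CHW."""
    chw = tuple(frame_hwc[i] for i in permutation("HWC", "CHW"))
    # The outer list adds the batch axis: (N=1 video, T frames, C, H, W)
    return (1, n_frames) + chw


if __name__ == "__main__":
    print(permutation("HWC", "CHW"))        # (2, 0, 1), the argument of screen.transpose
    print(video_shape((400, 600, 3), 200))  # (1, 200, 3, 400, 600)
```

With `n_eval_episodes=1`, T is simply the number of steps rendered during that evaluation episode.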
