IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> OpenAI Gym|cart-pole-v1任务的环境源码 -> 正文阅读

[Python知识库]OpenAI Gym|cart-pole-v1任务的环境源码

?本文代码来源于Gym官方文档

https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.pyhttps://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

cart-pole-v1任务的实现见pytorch实现CartPole-v1任务的DQN代码_bujbujbiu的博客-CSDN博客

描述

一根杆子由一个非驱动的接头连接到一辆小车上,小车沿着无摩擦的轨道移动。杆子被垂直放置在手推车上,目标是通过在手推车上施加左右方向的力来平衡杆子。

Action Space

动作是shape为(1,?)的ndarray数组,可以取值{0,1},表示小车被施加力的方向

?施加力所减少或增加的速度不是固定的,它取决于杆子指向的角度。杆子重心改变了移动下面的手推车所需的能量

Observation Space

状态是shape为(4, )的ndarray数组,包括小车位置,小车速度,杆子角度,杆子角速度

?上述定义的范围只是状态空间中各要素的可能取值,但是不是episode运行允许的范围,终止条件如下:

(1)小车x轴的位置(index 0)可以取值(-4.8,4.8),但是如果小车离开(-2.4,2.4)的范围,episode终止

(2)杆子角度可以在(-0.418, 0.418)?radians (or **±24°**)间,但是如果杆子超过(-0.2095, 0.2095)?(or **±12°**)范围,episode终止

Rewards

训楼目标是尽可能久的保持杆子不倒,因此每步都能获得+1的奖励,包括终止步,奖励阈值475

初始状态

所有观察值都被赋于(-0.05,0.05)中的一个均匀随机值

Episode终止

有下列情形之一的,episode终止:

(1)杆子角度大于±12°

(2)小车位置大于±2.4(小车中心到达显示屏边缘)

(3)episode长度大于500 (v0为200)

参数

gym.make('CartPole-v1')

完整代码

import math
from typing import Optional, Union

import numpy as np
import pygame
from pygame import gfxdraw

import gym
from gym import spaces, logger
from gym.utils import seeding

class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}

    def __init__(self):
        # 以下参数用于执行动作函数中,计算施加一定力对小车和杆子的影响
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # 杆子长度一半
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # 状态更新时间间隔(秒)
        self.kinematics_integrator = "euler"

        # 杆子角度阈值=12度,小车位置阈值=2.4
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # 定义Observation Space的四个要素
        high = np.array(
            [
                self.x_threshold * 2, # 小车位置4.8
                np.finfo(np.float32).max, # 返回float32类型数据最大值
                self.theta_threshold_radians * 2, # 杆子角度24度
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        # 定义action space,Discrete(2)={0,1}
        self.action_space = spaces.Discrete(2)

        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        self.steps_beyond_done = None

    def step(self, action):
        # assert相当于if else语句,满足前面条件则正常运行,否则报错或中断
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state
        # 力向右为正,像左为负
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # 施加力对杆子和小车影响的数学公式https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot ** 2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # 更新状态值
        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        # 判断是否出现终止条件
        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        # 根据执行动作后的状态计算奖励函数
        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned done = True. You "
                    "should always call 'reset()' once you receive 'done = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0
        # 返回执行一个动作后的新状态,奖励,是否终止
        return np.array(self.state, dtype=np.float32), reward, done, {}
    # 重置环境
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        super().reset(seed=seed)
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        if not return_info:
            return np.array(self.state, dtype=np.float32)
        else:
            return np.array(self.state, dtype=np.float32), {}
    # 图像引擎,用于展示训练过程中物体的变化(可不要)
    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((screen_width, screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((screen_width, screen_height))
        self.surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, screen_width, carty, (0, 0, 0))

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        if mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )
        else:
            return self.isopen

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            self.isopen = False

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-30 18:19:14  更:2022-03-30 18:23:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 4:37:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计