IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 基于PARL的cartpole实现+visualdl画图 -> 正文阅读

[Python知识库]基于PARL的cartpole实现+visualdl画图

在这里插入图片描述

首先新建一个文件夹my_cartpole

里面共有三个py文件,使用vim agent.py来创建
在这里插入图片描述

model.py

这个文件主要是用来定义前向网络,通常是一个值函数网络,输入是当前环境状态

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl

##继承parl.Model类
class my_cartpole(parl.Model):      
##构造函数__init__中声明要用到的中间层
    def __init__(self,obs_dim,act_dim):
        super(my_cartpole,self).__init__()
        hid1_size = act_dim*10
        self.fc1 = nn.Linear(obs_dim,hid1_size)
        self.fc2 = nn.Linear(hid1_size,act_dim)
##搭建前向网络
    def forward(self,x):
        out = paddle.tanh(self.fc1(x))  ##一层fc和tanh激活函数
        prob = F.softmax(self.fc2(out),axis=-1)##一层FC和softmax激活函数
        return prob

Agent.py

init :把前边定义好的alg传进来,作为agent的一个成员变量,用于后续的数据交互。
predict: 根据环境状态返回预测动作action,用于评估和部署agent
sample: 根据环境状态返回动作action,一般用于训练时候采样action进行探索。

import parl
import paddle
import numpy as np

class my_cartpole_agent(parl.Agent):
    def __init__(self,algorithm):
        super(my_cartpole_agent,self).__init__(algorithm)

##产生相对随机的动作两条路
    def sample(self,obs):
        obs = paddle.to_tensor(obs, dtype='float32')
        prob = self.alg.predict(obs)
        prob = prob.numpy()
        act = np.random.choice(len(prob),1,p=prob)[0] ##产生随机的动作
        return act
        
##根据argmax选择最优动作
    def predict(self,obs):
        obs = paddle.to_tensor(obs,dtype='float32')
        prob = self.alg.predict(obs)
        act = prob.argmax().numpy()[0]
        return act

    def learn(self,obs,act,reward):
        act = np.expand_dims(act,axis=-1)		
        reward = np.expand_dims(reward,axis=-1)
        
        obs = paddle.to_tensor(obs,dtype='float32')		##将numpy转换为tensor
        act = paddle.to_tensor(act,dtype='int32')
        reward = paddle.to_tensor(reward,dtype='float32')
        
        loss = self.alg.learn(obs,act,reward)
        return loss.numpy()[0]

train.py

from visualdl import LogWriter  ##用于在visualdl中画图
from visualdl.server import app  ##用于直接启动网站
import os	##文件路径
import gym	
import numpy as np
import parl
from parl.utils import logger		##为了打印
from agent import my_cartpole_agent		##从agent.py中调用my_cartpole_agent类
from model import my_cartpole				##从model.py中调用my_cartpole类

LEARNING_RATE = 1e-3

log_writer = LogWriter("./home/my_cartpole")

## 训练一次
def run_train_episode(agent,env):
    obs_list,action_list,reward_list = [],[],[]
    obs = env.reset()

    while True:
        obs_list.append(obs)	##就是在list后边加
        action = agent.sample(obs)
        action_list.append(action)

        obs,reward,done,info=env.step(action)	##做动作,拿返回值
        reward_list.append(reward)

        if done:
            break

    return obs_list,action_list,reward_list

def calc_reward_to_go(reward_list,gamma=1.0):
    for i in range(len(reward_list)-2,-1,-1):
        reward_list[i] += gamma * reward_list[i+1]

    return np.array(reward_list)	##从[1,2]array([1,2])

def run_test_episodes(agent,env,eval_episodes=5,render=False):
    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)
            obs,reward,isOver,_=env.step(action)
            episode_reward+=reward
            if render:
                env.render()
            if isOver:
                break
        eval_reward.append(episode_reward) ##为了后边的调用mean
    return np.mean(eval_reward)

##其实就是主函数
env = gym.make("CartPole-v0")
obs_dim = env.observation_space.shape[0] ##从环境中获取状态——当前的状态
act_dim = env.action_space.n ##获取动作维度

model = my_cartpole(obs_dim=obs_dim,act_dim=act_dim)	##前向网络定义完毕
alg = parl.algorithms.PolicyGradient(model,lr=LEARNING_RATE) ##算法调用PolicyGradient
agent = my_cartpole_agent(alg)	##把算法加到智能体上去

agent.restore('./model_save')##重新加载上次的模型

for i in range(1000):
    obs_list,action_list,reward_list = run_train_episode(agent,env) ##训练用sample

    if i % 10 == 0:
        logger.info("Episode{},Reward Sum{}.".format(i,sum(reward_list)))
        log_writer.add_scalar(tag='reward',step=i,value=sum(reward_list)) ##数据标签,为了visualdl

    batch_obs = np.array(obs_list)
    batch_action = np.array(action_list)
    batch_reward = calc_reward_to_go(reward_list)

    agent.learn(batch_obs,batch_action,batch_reward)

    if(i+1) % 100 == 0:
        total_reward = run_test_episodes(agent,env)		##测试用predict
        logger.info('Test reward:{}'.format(total_reward))

agent.save('./model_save')	##保存模型
app.run(logdir="./home/my_cartpole") ##日志文件路径


在这里插入图片描述

子函数:
在这里插入图片描述
什么叫张量:

在这里插入图片描述
array中的数据类型必须都一样;list则不必要
axis=0:在第一维操作,三个矩阵中最大的
axis=1:在第二维操作,每个矩阵的列最大
axis=-1:在最后一维操作,每个矩阵的行最大
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
解释expand_dims
在这里插入图片描述
解释to_tensor
在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-08-22 13:29:59  更:2021-08-22 13:32:20 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/26 11:43:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计