IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于深度强化学习的绘画智能体 代码分析(四) -> 正文阅读

[人工智能]基于深度强化学习的绘画智能体 代码分析(四)

Github源码

tensorboard.py

import PIL #图像处理库
import scipy.misc #将数组保存成图像形式
from io import BytesIO #在内存中读写bytes
import tensorboardX as tb
from tensorboardX.summary import Summary

class TensorBoard(object):
    def __init__(self, model_dir): #model_dir是下载模型保存地址
        self.summary_writer = tb.FileWriter(model_dir) #指定一个文件用来保存图

    def add_image(self, tag, img, step):
        summary = Summary()
        bio = BytesIO() #创建一个类二进制文件对象

        if type(img) == str:
            img = PIL.Image.open(img)  #返回PIL.Image.Image的类型
        elif type(img) == PIL.Image.Image:
            pass #不需要转换
        else:
            img = PIL.Image.fromarray(img) #array转换成image

        img.save(bio, format="png")
        image_summary = Summary.Image(encoded_image_string=bio.getvalue()) #可视化
        summary.value.add(tag=tag, image=image_summary) #按照标签加入进去
        self.summary_writer.add_summary(summary, global_step=step)  #global_step训练步数

    def add_scalar(self, tag, value, step): #加的是scalar具体的值
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.summary_writer.add_summary(summary, global_step=step)

Github源码
util.py

import os
import torch
from torch.autograd import Variable

USE_CUDA = torch.cuda.is_available() #你电脑GPU能否PyTorch调用。

def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))

def to_numpy(var): #把tensor变成numpy
    return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 
#.data是读取Variable中的tensor   .cpu是把数据转移到cpu    .numpy()把tensor变成numpy

def to_tensor(ndarray, device): #和上面的相反
    return torch.tensor(ndarray, dtype=torch.float, device=device)

def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()): #parameters()会返回一个生成器(迭代器),生成器每次生成的是Tensor类型的数据.
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau  #加了tau(0~1),复制一部分
        )

def hard_update(target, source):
    for m1, m2 in zip(target.modules(), source.modules()):
        m1._buffers = m2._buffers.copy()
    for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data) #source.parameters一对一的全部复制到target_param

def get_output_folder(parent_dir, env_name):
    """Return save folder. #返回保存文件夹。
    Assumes folders in the parent_dir have suffix -run{run
    number}. #假定父目录中的文件夹具有后缀-run{run number}
   Finds the highest run number and sets the output folder
    to that number + 1. #查找最高的运行编号,并将输出文件夹设置为该编号+1。
   This is just convenient so that if you run the
    same script multiple times tensorboard can plot all of the results
    on the same plots with different names. #这非常方便,如果您多次运行同一脚本,tensorboard可以使用不同的名称在相同的绘图上绘制所有结果。
    Parameters
    ----------
    parent_dir: str
      Path of the directory containing all experiment runs.
    Returns
    -------
    parent_dir/run_dir
      Path to this run's save directory.
    """
    os.makedirs(parent_dir, exist_ok=True) #创建目录
    experiment_id = 0
    for folder_name in os.listdir(parent_dir):
        if not os.path.isdir(os.path.join(parent_dir, folder_name)):
            continue
        try:
            folder_name = int(folder_name.split('-run')[-1])   #获取文件扩展名
            if folder_name > experiment_id:
                experiment_id = folder_name
        except:
            pass
    experiment_id += 1

    parent_dir = os.path.join(parent_dir, env_name)
    parent_dir = parent_dir + '-run{}'.format(experiment_id)
    os.makedirs(parent_dir, exist_ok=True)
    return parent_dir
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-12 23:26:02  更:2021-10-12 23:27:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 10:17:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码