IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习学习工具包Tools——LearningRateHistory学习率记录工具 -> 正文阅读

[人工智能]深度学习学习工具包Tools——LearningRateHistory学习率记录工具

工具包Tools——LearningRateHistory学习率记录

import datetime
import os

import matplotlib.pyplot as plt
import scipy.signal


class LearningRateHistory:
    def __init__(self, root):
        """
        :param root: 根目录,主要是保存这个项目运行的所有结果的文件夹
        """
        self.root = root
        self.__assert_dir(root)

        current_time = datetime.datetime.now()
        str_time = datetime.datetime.strftime(current_time, '%Y_%m_%d_%H_%M_%S')
        self.str_time = str_time

        self.dir_name = 'lr_' + self.str_time
        self.dirs = os.path.join(self.root, self.dir_name)
        self.__assert_dir(self.dirs)

        self.learning_rate_path = os.path.join(self.dirs, 'learning_rate.txt')

        self.learning_rate_list = []

    @staticmethod
    def __assert_dir(dir_name):  # 判断文件夹是否存在
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)

    def add_data(self, data):  # 实例化后直接调用这个函数就行,添加loss数据
        self.learning_rate_list.append(data)
        self.save_data(self.learning_rate_path, data)

        self.plt_picture()

    @staticmethod
    def save_data(save_path, data):  # 保存数据
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(data) + '\n')

    def plt_picture(self):  # 绘制数据
        length = range(len(self.learning_rate_list))

        plt.figure('lr')
        plt.plot(length, self.learning_rate_list, 'red', linewidth=2, label='learning rate')  # 在画布上绘制曲线

        try:
            if len(self.learning_rate_list) < 25:
                num = 5
            else:
                num = 15

            plt.plot(length, scipy.signal.savgol_filter(self.learning_rate_list, num, 3), 'green', linestyle='--',
                     linewidth=2, label='smooth learning rate')
        except:
            pass

        plt.grid(True)  # 是否打开画布网格
        # 设置x轴和y轴
        plt.xlabel('Epoch')
        plt.ylabel('lr')
        # 图例的位置上,右上角
        plt.legend(loc="upper right")
        # 保存图片
        plt.savefig(os.path.join(self.dirs, "epoch_loss_" + str(self.str_time) + ".png"))
        plt.close('lr')

if __name__ == '__main__':
    # 使用例子
    lr = LearningRateHistory(root='/lr')
    lr.add_data(data=0.0005)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-29 09:05:19  更:2021-08-29 09:06:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 17:40:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码