工具包Tools——LearningRateHistory学习率记录
import datetime
import os
import matplotlib.pyplot as plt
import scipy.signal
class LearningRateHistory:
    """Record learning-rate values, append them to a text file and plot them.

    Each instance creates a timestamped sub-directory ``lr_<YYYY_mm_dd_HH_MM_SS>``
    under ``root``. Every call to :meth:`add_data` appends the value to
    ``learning_rate.txt`` in that directory and redraws a PNG plot of all values
    recorded so far.
    """

    def __init__(self, root):
        """
        :param root: root directory, i.e. the folder that holds all results
            produced by this project's runs. Created if it does not exist.
        """
        self.root = root
        self.__assert_dir(root)
        # Timestamp shared by the sub-directory name and the plot file name.
        self.str_time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        self.dir_name = 'lr_' + self.str_time
        self.dirs = os.path.join(self.root, self.dir_name)
        self.__assert_dir(self.dirs)
        self.learning_rate_path = os.path.join(self.dirs, 'learning_rate.txt')
        self.learning_rate_list = []

    @staticmethod
    def __assert_dir(dir_name):
        """Create ``dir_name`` (including missing parents) if it does not exist."""
        # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(dir_name, exist_ok=True)

    def add_data(self, data):
        """Record one learning-rate value, persist it and redraw the plot.

        :param data: learning rate to record; written to the text file as ``str(data)``.
        """
        self.learning_rate_list.append(data)
        self.save_data(self.learning_rate_path, data)
        self.plt_picture()

    @staticmethod
    def save_data(save_path, data):
        """Append ``str(data)`` plus a newline to the UTF-8 text file ``save_path``."""
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(data) + '\n')

    @staticmethod
    def _smooth_curve(values):
        """Return a Savitzky-Golay smoothed copy of ``values``, or None if too short to smooth.

        Uses a 5-point window for fewer than 25 values and a 15-point window
        otherwise, with a cubic polynomial. savgol_filter's default 'interp'
        mode requires the window to be no longer than the data, so shorter
        inputs return None instead of raising.
        """
        window = 5 if len(values) < 25 else 15
        if len(values) < window:
            return None
        return scipy.signal.savgol_filter(values, window, 3)

    def plt_picture(self):
        """Redraw the learning-rate curve (plus a smoothed curve when possible) and save it as a PNG."""
        epochs = range(len(self.learning_rate_list))
        plt.figure('lr')
        try:
            plt.plot(epochs, self.learning_rate_list, 'red', linewidth=2, label='learning rate')
            try:
                smoothed = self._smooth_curve(self.learning_rate_list)
            except (TypeError, ValueError):
                # Smoothing is best-effort: skip it for data savgol_filter rejects,
                # without hiding unrelated errors such as KeyboardInterrupt.
                smoothed = None
            if smoothed is not None:
                plt.plot(epochs, smoothed, 'green', linestyle='--',
                         linewidth=2, label='smooth learning rate')
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('lr')
            plt.legend(loc="upper right")
            # NOTE(review): the file name says "loss" although this plots the
            # learning rate; kept unchanged so existing output paths stay valid.
            plt.savefig(os.path.join(self.dirs, "epoch_loss_" + self.str_time + ".png"))
        finally:
            # Always release the named figure, even if plotting or saving fails.
            plt.close('lr')
if __name__ == '__main__':
    # Demo: record one learning-rate value under ./lr, relative to the current
    # working directory (a relative path avoids needing write access to the
    # filesystem root).
    lr = LearningRateHistory(root='lr')
    lr.add_data(data=0.0005)
|