?Keras框架中的History对象记录的是网络训练过程中的train_acc/train_loss/val_acc/val_loss等数值,而调整网络的超参数都是要通过这些数值进行调整。可视化该对象的代码,使我们对模型的训练情况有更加直观的认识和掌握,为下一次训练过程中的参数调节提供有价值的参考。
# -*- encoding: utf-8 -*-
'''
File:read_history.py
Author:Mr. Luo
Date:2022/3/1 14:09
Feature: 读取history数据文件并可视化
'''
import logging
logging.basicConfig(level=logging.DEBUG,
filemode='w',
format='%(asctime)s:%(levelname)s:%(message)s',
datefmt='%Y-%d-%m %H:%M:%S')
import matplotlib.pyplot as plt
import pickle
with open(r'E:\update_detection\data_create\trainHistoryDict.txt', 'rb') as f:
his_dict = pickle.load(f)
# 这两行代码解决 plt 中文显示的问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 输入纵坐标轴数据与横坐标轴数据
loss = his_dict.get("loss")
acc = his_dict.get("acc")
val_loss = his_dict.get("val_loss")
val_acc = his_dict.get("val_acc")
epoch = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# 4 个 plot 函数画出 4 条线,线形为折线,每条线对应各自的标签 label
plt.plot(epoch, loss, 'ro:', label='loss')
plt.plot(epoch, acc, 'r.-', label='acc')
plt.plot(epoch, val_loss, 'go:', label='val_loss')
plt.plot(epoch, val_acc, 'g.-', label='val_acc')
plt.xticks(epoch) # 设置横坐标刻度为给定的epoch
plt.xlabel('epoch') # 设置横坐标轴标题
plt.legend() # 显示图例,即每条线对应 label 中的内容
plt.show() # 显示图形
运行结果:
?
|