IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习入门系列9:用检查点保存训练期间最好的模型 -> 正文阅读

[人工智能]深度学习入门系列9:用检查点保存训练期间最好的模型

大家好,我技术人Howzit,这是深度学习入门系列第八篇,欢迎大家一起交流!

深度学习入门系列1:多层感知器概述
深度学习入门系列2:用TensorFlow构建你的第一个神经网络
深度学习入门系列3:深度学习模型的性能评价方法
深度学习入门系列4:用scikit-learn找到最好的模型
深度学习入门系列5项目实战:用深度学习识别鸢尾花种类
深度学习入门系列6项目实战:声纳回声识别
深度学习入门系列7项目实战:波士顿房屋价格回归
深度学习入门系列8:用序列化保存模型便于继续训练
深度学习入门系列9:用检查点保存训练期间最好的模型
待更新……
深度学习入门系列10:从绘制记录中理解训练期间的模型行为
深度学习入门系列11:用Dropout正则减少过拟合
深度学习入门系列12:使用学习规划来提升性能
深度学习入门系列13:卷积神经网络
深度学习入门系列14项目实战:手写数字识别
深度学习入门系列15用图像增强改善模型性能
深度学习入门系列16:图像中对象识别项目
深度学习入门系列17项目实战:从电影评论预测情感
深度学习入门系列18:递归神经网络


深度学习模型的训练需要花费几个小时,几天甚至几周来训练,若是训练过程意外中断,你会丢失很多工作。在这部分,你将学习在python中如何使用Keras库设立训练模型的检查点。在学习完这节课之后,你将学到:

  • 在训练时,设置神经网络模型检查点机制的重要性。
  • 在训练时,如何为每个改善模型的设置检查点。
  • 在训练时,如何设置检最好模型的检查点。

让我们开始。

9.1 设置神经网络模型检查点

对于耗时处理,应用检查点机制是一种故障容错技术。它是一种在系统出错情况下快照系统状态的方法。如果有问题,不会全部丢失。检查点可以直接使用,或者作为新程序的起点,在他中断的地方加载数据。当训练深度学习模型,检查点能捕获模型权重。这些权重用于做预测,或者用于继续训练的基础。

Keras库通过回调API提供了检查点功能。ModelCheckpoint 回调类允许你定义在哪设置权重检查点,如何命名文件,在什么情况下做模型检查点。API允许指定记录的度量,如训练集或者交叉数据集的误差或者精度。您也可以指定是否寻找最大值或最小评分的改进。最后,你用于保存权重的文件名可以包含变量,像迭代次数或者度量。当在模型调用**fit()**函数时,ModelCheckpoint 实例被传入到整个处理过程中。注意,你可能需要安装h5py库。

9.2检查点神经网络模型改善

检查点的用途是每次在训练期间观察到改进时输出模型权重。下面的例子针对糖尿病发病的二分类问题创建了一个小的神经网络。这个例子中使用33%的数据来验证。

检查点的建立是为了在验证数据集(monitor=‘val_acc’and mode=‘max’)上分类精度有改进时保存神经网络权重。这些权重被存在文件中,它的名字包含分数,weights-improvements-epoch-val_acc=.2f.hdf5

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt
import numpy

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=',')
print(dataset)

X = dataset[:, 0:8]
Y = dataset[:, 8]

# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# compile model
model.compile(loss="binary_crossentropy", optimizer='adam', metrics=['accuracy'])
filepath = "weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, model='max')
callbacks_list = [checkpoint]
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

运行这个例子会产生下面这个结果,为了简洁明了做了删减。在结果中,你能够看到在验证集上模型精度的改善导致一个新的权重文件写入到磁盘。

...
Epoch 00139: val_acc improved from 0.78346 to 0.78740, 
saving model to weights-improvement-139-0.79.hdf5 
Epoch 00140: val_acc did not improve 
Epoch 00141: val_acc did not improve 
Epoch 00142: val_acc did not improve 
Epoch 00143: val_acc did not improve 
Epoch 00144: val_acc improved from 0.78740 to 0.79528, 
saving model to weights-improvement-144-0.80.hdf5

Epoch 00145: val_acc did not improve 
Epoch 00146: val_acc did not improve 
Epoch 00147: val_acc did not improve 
Epoch 00148: val_acc did not improve 
Epoch 00149: val_acc did not improve

你也将在项目文件夹下看到一些文件,包括HDF5格式的网络权重,例如:

... 
weights-improvement-53-0.76.hdf5 
weights-improvement-71-0.76.hdf5 
weights-improvement-77-0.78.hdf5 
weights-improvement-99-0.78.hdf5

这个是非常简单的检查点设置策略。如果验证集精度随着迭代周期波动很大的话,可能会产生许多不必要的文件。因此,它将确保在运行期间对最好的模型进行快照。

9.3 仅设置最好神经网路模型的检查点

一个比较简单的检查点策略是保存模型权重到同一文件,仅仅只要验证集精度提升。这很容易做到,使用上面相同的代码,把输出文件名固定即可(不包含分数或迭代信息。)。这种情况下,仅当验证数据集上的模型的分类精度比目前为止所看到的最好的还好时,模型权重才被写入到weights.best.hdf5中。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy

# fix random seed for reproducibility

seed = 7
numpy.random.seed(seed)
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=',')
print(dataset)

X = dataset[:, 0:8]
Y = dataset[:, 8]

# create model

model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# compile model
model.compile(loss="binary_crossentropy", optimizer='adam', metrics=['accuracy'])

filepath = "weights-best.hdf5"

checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, model='max')

callbacks_list = [checkpoint]

model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

运行上面例子得到如下输出(为简洁起见)

... 
Epoch 00139: val_acc improved from 0.79134 to 0.79134, saving model to weights.best.hdf5 
Epoch 00140: val_acc did not improve 
Epoch 00141: val_acc did not improve 
Epoch 00142: val_acc did not improve 
Epoch 00143: val_acc did not improve 
Epoch 00144: val_acc improved from 0.79134 to 0.79528, saving model to weights.best.hdf5 
Epoch 00145: val_acc improved from 0.79528 to 0.79528, saving model to weights.best.hdf5 
Epoch 00146: val_acc did not improve 
Epoch 00147: val_acc did not improve 
Epoch 00148: val_acc did not improve 
Epoch 00149: val_acc did not improve
Listing 14.5: Sample Output 

你应该在本地目录中看到权重文件

weights.best.hdf5

9.4 加载已保存神经网络模型

既然你已经知道如何在训练过程中设置你的深度学习模型的检查点,那么你需要复习如何加载并使用这个检查点模型。这个检查点模型仅仅包含模型权重。它假定你知道网络结构,这也可以序列化为JSON文件或YAML格式。下面这个例子,模型的结构是知道的,从工作目录的 weights.best.hdf5文件加载上面实验中最好的权重参数。这个模型被用于在整个数据集上做预测。

# How to load and use weights from a checkpoi

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation="relu"))
model.add(Dense(8, activation="relu"))
model.add(Dense(1, activation="sigmoid"))

# load weights
model.load_weights("weights.best.hdf5")

# Compile model (required to make predictions)

model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

print("Created model and loaded weights from file")

# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")

# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
Y = dataset[:, 8]

# estimate accuracy on whole dataset using loaded weights
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))

运行例子得到下面结果

Created model and loaded weights from file acc: 77.34%

9.5 总结

在这节课中,你已经学习了在耗时中设置深度学习模型检查点的重要性。你已学到:

  • 如何使用Keras来为每次模型改进设置检查点。
  • 如何在训练中仅仅为最好的模型设置检查点。
  • 如何从文件中加载检查点模型并且用它来做预测。

9.5.1

在长时间训练模式中,你现在知道如何设置你深度学习模型检查点。在下一课中,您将学习如何在训练期间收集,检查和绘制有关模型的度量标准

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-26 12:06:37  更:2021-08-26 12:06:44 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 21:53:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码