IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TensorFlow Estimator 中的模型保存为Checkpoints格式 -> 正文阅读

[人工智能]TensorFlow Estimator 中的模型保存为Checkpoints格式

本文介绍了 Estimators 模型的保存和恢复。

TensorFlow提供了两种模型格式:

  • checkpoints:这种格式依赖于创建模型的代码。
  • SavedModel:这种格式与创建模型的代码无关。

本文档主要介绍checkpoints。

1. 保存经过部分训练的模型

Estimators 在训练过程中会自动将以下内容保存到磁盘:

  • chenkpoints:训练过程中的模型快照。
  • event files:其中包含 TensorBoard 用于创建可视化图表的信息。

通过 model_dir 参数,我们可以指定 Estimator 保存上述文件时的顶级目录。

# 实例化 estimator
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

# 训练 estimator
classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

如下图所示,第一次调用?train?方法会将 checkpoints 和 event files 文件添加到?model_dir?目录中。

来查看?model_dir?目录中的内容:

?我们可以看到,Estimator在step 1(训练开始)和step 12130(训练结束)创建了 checkpoints 文件。

1.2 创建 Checkpoints 的频率

默认情况下,Estimator 会根据以下策略来写入 checkpoints。

  • 每10分钟(600秒)向磁盘写入一个 checkpoint。
  • 在 train 方法开始(第一次迭代)和结束(最后一次迭代)时写入一个 checkpoint。
  • model_dir 目录中保留 5 个最近写入的检查点。

当然,你可以按如下方式修改 checkpoint 的写入策略:

  • 创建一个tf.estimator.RunConfig对象来定义 checkpoint 写入策略。
  • 在实例化 Estimator 时,将 RunConfig 对象传给 Estimator 的 config 参数。

下面的代码将 checkpoint 写入间隔设置为20分钟,并且保留最近的10个 checkpoints:

est_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # 每20分钟保存一次 checkpoints
    keep_checkpoint_max = 10,       # 保留最新的10个checkpoints
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=est_config)

?3. 恢复模型

第一次调用 Estimator 的 train 方法时,TensorFlow会保存 checkpoint 文件到 model_dir 目录。随后调用 tarin、evaluate、predict 方法将进行如下操作:

  • Estimator 通过运行 model_fn 来构建模型的计算图。
  • Estimator 从 checkpoints 中初始化模型参数。

?2.1 避免不当恢复

仅在模型和checkpoint兼容的情况下,才能从 checkpoint 恢复模型的状态。例如,假设您训练了DNNClassifier包含两个隐藏层的 Estimator,每个隐藏层有10个节点:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

训练之后,如果您将每个隐藏层中的神经元数量从10更改为20,并尝试重新训练模型:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier2.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由于 checkpoint 中的状态与描述的模型不兼容,因此重新训练失败并出现以下错误:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

如果保存报错,粗暴的方式是:将原来保存的模型文件删除掉,新文件下保存模型不会出错。

参考:

tensorflow中模型的保存与使用总结 — carlos9310

https://blog.csdn.net/u014061630/article/details/82901646

保存和恢复

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-09 11:39:17  更:2021-12-09 11:39:35 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 1:37:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码