IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 快速上手tensorflw:模型保存、加载 -> 正文阅读

[人工智能]快速上手tensorflw:模型保存、加载

1. 代码实例

(1) 存储模型

举一个简单的前向传播的例子:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np

# 前向传播
def forward_pass(x11, w11, w22):
    y11 = tf.matmul(x11, w11, name="y11")
    y22 = tf.matmul(y11, w22, name="y22")
    return y22


# 定义权值
w1 = tf.Variable(tf.random.normal(shape=[2,5]), name='w1')
w2 = tf.Variable(tf.random.normal(shape=[5,2]), name='w2')
# 定义输入
x1 = tf.placeholder(tf.float32, shape=(None, 2), name="x1")

# 输出
y2 = forward_pass(x1, w1, w2)

# 未封装成函数时,输出的写法
# y1 = tf.matmul(x1,w1, name="y1")
# y2 = tf.matmul(y1,w2, name="y2")

# 建立session,初始化变量
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# 开始run
input_x = np.zeros((3,2))
dict = {x1:input_x}
sess.run(y2,dict)

# 保存模型
saver.save(sess, 'Test/TestModel')

这里把前向传播的过程封装成了函数,更加符合平时写程序时的习惯。

我们在保存模型后,会出现一些文件:

?其中:meta文件是用来存储?图的计算过程?的,data里则存储了?权值?的具体值。

(2)加载模型:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np

# 加载图
sess=tf.Session()
#先加载图和参数变量
saver = tf.train.import_meta_graph('Test/TestModel.meta')
saver.restore(sess, tf.train.latest_checkpoint('Test'))

# 访问placeholders变量
graph = tf.get_default_graph()
# 输入的变量、以及权值
x1 = graph.get_tensor_by_name("x1:0")
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
print(sess.run('w1:0'))
print(sess.run('w2:0'))

# 获取操作
y1 = graph.get_tensor_by_name("y11:0")
y2 = graph.get_tensor_by_name("y22:0")

# 这样的话不用自己写结构,也可以根据输入直接得到输出
input_x = np.zeros((3,2))
feed_dict = {x1:input_x}
print(sess.run(y2,feed_dict))

2. 详细解释

(1) name属性

①在一开始写程序时,就需要在相应变量上添加:name属性,如输入变量、网络权值、前向传播的输出值、损失值等。在加载模型时,我们需要根据name属性,导入对应的变量。

值得注意的是,name属性需要在tf变量上添加,这样我们在计算想要复现的变量时,需要使用tensorflow的计算语法。

例如:我们想在加载模型的程序中导入a的计算图

写加法时,不能:

a = b + c

而是需要:

a = tf.add_n([b, c], name = 'a')

乘法等计算过程亦是如此。

②tensorflow是根据计算图来运行的,而计算任何变量都是需要用到一开始的输入的。因此如果以后要加载模型,那我们在写程序时,需要对?输入?添加name属性。其他的变量则是用到什么,就给它添加name属性。

(2) 加载模型后的计算

①我们利用tf.train.import_meta_graph('Test/TestModel.meta')导入了计算图,这样所有的计算过程我们都不用再自己写,想计算哪个变量,就直接:?sess.run(目标变量的name, feed_dict)

②tensorflow甚至可以保存方法中的局部变量的计算图,例如程序中的y11是个局部变量,可它的计算过程也能被存储下来,确实是神奇。

(3) 个人遇到的多个输入的情况(可以不看)

当我们有多个输入都要进行前向传播、且前向传播为函数形式,这时存储的计算图为第一次调用前向传播的输入变量name。

举个例子:有两个输入变量:x1: name="x1",?x2: name="x2"

都放到刚才的程序计算并保存模型后,在加载模型中,我们只需要创建name="x1"的变量即可。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-17 22:08:21  更:2022-03-17 22:10:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 14:54:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码