IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> Tensorflow 静态图PB模型修改(OP修改) -> 正文阅读

[Python知识库]Tensorflow 静态图PB模型修改(OP修改)

def load_pb_graph(path):
    with tf.gfile.GFile(path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as g:
        tf.import_graph_def(graph_def, name=None)
    return g


model_filename = '111.pb'
g = load_pb_graph(model_filename)
#加载原图完毕

new_model = tf.GraphDef()
with tf.Session(graph=g) as sess:
    for n in sess.graph_def.node:
        if n.name in ['import/input_ids','import/input_mask', 'import/token_type_ids']:
            nn = new_model.node.add()
            nn.op = n.op
            nn.name = n.name
            nn.attr['dtype'].CopyFrom(tf.AttrValue(type=tf.int32.as_datatype_enum))
            s = tensor_shape_pb2.TensorShapeProto()
            d1 = tensor_shape_pb2.TensorShapeProto.Dim()
            d2 = tensor_shape_pb2.TensorShapeProto.Dim()
            d1.size = 1
            d2.size = 7
            s.dim.extend([d1,d2])
            nn.attr['shape'].shape.CopyFrom(s)
            for i in n.input:
                nn.input.extend([i])
        else:
            new_model.node.append(n)
            # nn = new_model.node.add()
            # nn.CopyFrom(n) 太过于耗时,可以使用append直接加入old节点

print('*'*100)
#将新图注入到默认的Graph中
#tf.import_graph_def(new_model, name='')  # Imports `graph_def` into the current default `Graph`

# 测试案例
with tf.Session() as sess:
    tf.train.write_graph(new_model, logdir='./', name='graph_def_new.pb',  as_text=False)

pb输入会有import,ckpt就是node的名字。
这里将输入的128维度替换成7维,主要是为了测时间,如果是标准bert,后续还要修改reshape,这里对后续的所有操作融合成了一个OP,所以仅仅需要修改输入的dim即可。

最后保存pb时候, tf.train.write_graph,特别注意可能特别慢,因为需要把as_text设置成False,否则图被写成一个 text proto。

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-09-10 10:48:01  更:2021-09-10 10:48:18 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 13:22:59-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码