def load_pb_graph(path):
with tf.gfile.GFile(path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as g:
tf.import_graph_def(graph_def, name=None)
return g
model_filename = '111.pb'
g = load_pb_graph(model_filename)
#加载原图完毕
new_model = tf.GraphDef()
with tf.Session(graph=g) as sess:
for n in sess.graph_def.node:
if n.name in ['import/input_ids','import/input_mask', 'import/token_type_ids']:
nn = new_model.node.add()
nn.op = n.op
nn.name = n.name
nn.attr['dtype'].CopyFrom(tf.AttrValue(type=tf.int32.as_datatype_enum))
s = tensor_shape_pb2.TensorShapeProto()
d1 = tensor_shape_pb2.TensorShapeProto.Dim()
d2 = tensor_shape_pb2.TensorShapeProto.Dim()
d1.size = 1
d2.size = 7
s.dim.extend([d1,d2])
nn.attr['shape'].shape.CopyFrom(s)
for i in n.input:
nn.input.extend([i])
else:
new_model.node.append(n)
# nn = new_model.node.add()
# nn.CopyFrom(n) 太过于耗时,可以使用append直接加入old节点
print('*'*100)
#将新图注入到默认的Graph中
#tf.import_graph_def(new_model, name='') # Imports `graph_def` into the current default `Graph`
# 测试案例
with tf.Session() as sess:
tf.train.write_graph(new_model, logdir='./', name='graph_def_new.pb', as_text=False)
pb输入会有import,ckpt就是node的名字。 这里将输入的128维度替换成7维,主要是为了测时间,如果是标准bert,后续还要修改reshape,这里对后续的所有操作融合成了一个OP,所以仅仅需要修改输入的dim即可。
最后保存pb时候, tf.train.write_graph,特别注意可能特别慢,因为需要把as_text设置成False,否则图被写成一个 text proto。
|