IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 记录::Opencv调用tensorflow2.x-Keras模型 -> 正文阅读

[Python知识库]记录::Opencv调用tensorflow2.x-Keras模型

需要用C++调用tensoeflow模型,但我发现现在的tensorflow2.x的版本都是用keras搭建的,不想用动态库,决定直接用Opencv调用模型。

库版本:

tensorflow 2.2.0

opencv 4.2.0.32

参考:OpenCV使用Tensorflow2-Keras模型_风翼冰舟的博客-CSDN博客_opencv调用keras模型

tensorflow?Frozen-Graph-TensorFlow/TensorFlow_v2 at master · leimao/Frozen-Graph-TensorFlow · GitHub

主要保存模型部分

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="simple_frozen_graph.pb",
                      as_text=False)

用的github里面的例子1测试

训练模型,用的mnist数据集,下载数据集部分,如果报错(url无效什么的)可以手动下载后放在C:\Users\Administrator\.keras\datasets\fashion-mnist 里面

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

?训练以及保存模型

def trainmodel2():
    tf.random.set_seed(seed=0)

    # Get data
    (train_images, train_labels), (test_images,
                                   test_labels) = get_fashion_mnist_data()

    # Create Keras model
    model = keras.Sequential(layers=[
        keras.layers.InputLayer(input_shape=(28, 28), name="input"),
        keras.layers.Flatten(input_shape=(28, 28), name="flatten"),
        keras.layers.Dense(128, activation="relu", name="dense"),
        keras.layers.Dense(10, activation="softmax", name="output")
    ], name="FCN")

    # Print model architecture
    model.summary()

    # Compile model with optimizer
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    # Train model
    model.fit(x={"input": train_images}, y={"output": train_labels}, epochs=1)


    # Save model to SavedModel format
    tf.saved_model.save(model, "./frozen_models/simple_model")
    #tf.model.save((model, "./frozen_models/simple_model"))
    # Convert Keras model to ConcreteFunction
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()
  
    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="simple_frozen_graph.pb",
                      as_text=False)

tensorflow调用测试

def tftest():
    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile("./frozen_models/simple_frozen_graph.pb", "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        loaded = graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["x:0"],
                                    outputs=["Identity:0"],
                                    print_graph=True)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)
    
    #调用测试
    test_x = cv2.imread("1.png",0)
    test_x=cv2.resize(test_x,(28,28))
    pred_y = frozen_func(x=tf.constant(test_x,dtype=tf.float32))[0]
    print(pred_y[0].numpy()) 

python-opencv调用测试

def opencvtest():
    test_x = cv2.imread("1.png",0)
    test_x = cv2.dnn.blobFromImage(image=test_x, scalefactor=1.0, size=(28, 28))
    net = cv2.dnn.readNetFromTensorflow("./frozen_models/simple_frozen_graph.pb")
    net.setInput(test_x)
    pred = net.forward()
    print(pred)

c++opencv调用测试

int main() {
	Mat test_x = imread("1.png", 0);
	test_x = cv::dnn::blobFromImage(test_x,1.0,Size(28, 28));
	dnn::Net net = cv::dnn::readNetFromTensorflow("simple_frozen_graph.pb");
	net.setInput(test_x);
	Mat pred = net.forward();
	cout << pred << endl; 

	return 0;
}

更多完整代码参考:

GitHub - ziyaoma/Opencv-Tensorflow2.x: 用opencv调用tensorflow2.x的keras训练的模型

问题:

用tensorflow2.6版本出现opencv调动报错,

error:?(-2:Unspecified?error)?Can't?create?layer?"NoOp"?of?type?"NoOp"?in?function?'cv::dnn::dnn4_v20191202::LayerData::getLayerInstance'

解决:

估计是版本不匹配吧,降到了2.2就可以了,

?

对比发现输出多了一层,不知道怎么解决,有知道的大佬求告知

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-04-24 09:24:12  更:2022-04-24 09:25:59 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 10:35:11-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计