IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Tensorflow C++使用ops::BatchMatMul实现特征批量乘法 -> 正文阅读

[人工智能]Tensorflow C++使用ops::BatchMatMul实现特征批量乘法

本例主要测试Tensorflow C++ API中的ops::BatchMatMul算子。
整体来说这个算子比较简单。但是难在官网没有例子。Tensorflow的单测也写得不到位。
话不多说,上代码。
代码结构如下,
image.png

conanfile.txt

 [requires]
 gtest/1.10.0
 glog/0.4.0
 protobuf/3.9.1
 eigen/3.4.0
 dataframe/1.20.0
 opencv/3.4.17
 boost/1.76.0
 abseil/20210324.0
 xtensor/0.23.10

 [generators]
 cmake

CMakeLists.txt

cmake_minimum_required(VERSION 3.3)


project(test_math_ops)

set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})


file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 

file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach( test_file ${test_file_list} )
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file}  ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})

tf_math2_test.cpp

#include <string>
#include <vector>
#include <glog/logging.h>
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"


using namespace tensorflow;

int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // 日志级别 INFO, WARNING, ERROR, FATAL 的值分别为0、1、2、3
    FLAGS_minloglevel = 0;

    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    ::testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}


TEST(TfArthimaticTests, BatchMatMul) {
    // BatchMatMul  测试
    // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
    
    // 2 * 1 * 2
    // 2 * 2 * 3
    // = 
    // 2 * 1 * 3
    Scope root = Scope::NewRootScope();
    auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 1, 2});
    /**
     * @brief Left param
     * {{1, 2},
     *  {3, 4}}
     */
    auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 3});
    /**
     * @brief Right param
     *  {{{1, 2, 3}, {4, 1, 2}},
     *   {{3, 4, 5}, {6, 7, 8}}}
     */

    /**
     * @brief Result
     * {{9, 4, 7},
     *  {33, 40, 47}}
     */
    auto batch_op = ops::BatchMatMul(root, left_, right_);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    session.Run({batch_op.output}, &outputs);

    test::PrintTensorValue<int>(std::cout, outputs[0]);
    test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({9, 4, 7, 33, 40, 47}, {2, 1, 3}));
}

TEST(TfArthimaticTests, BatchMatMulAdjXY) {
    // BatchMatMul  测试
    // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
    
    // 2 * 1 * 2
    // 2 * 2 * 3
    // = 
    // 2 * 1 * 3
    Scope root = Scope::NewRootScope();
    auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 2, 1});
    /**
     * @brief Left param
     * {{{1}, 
     *   {2}},
     *  {{3},
     *   {4}}}
     */
    auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 3, 2});
    /**
     * @brief Right param
     *  {{{1, 2}, 
     *   {3, 4}, 
     *   {1, 2}}, 
     *
     *   {{3, 4}, 
     *   {5, 6}, 
     *   {7, 8}}  
     * }
     */
    
   
    /**
     * @brief Result
     * {{5, 11, 5},
     *  {25, 39, 53}}
     */

    auto attrs = ops::BatchMatMul::AdjX(true).AdjY(true);
    auto batch_op = ops::BatchMatMul(root, left_, right_, attrs);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    session.Run({batch_op.output}, &outputs);

    test::PrintTensorValue<int>(std::cout, outputs[0]);
    test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({5, 11, 5, 25, 39, 53}, {2, 1, 3}));
}

程序输出如下,代表两个算子均测试通过。
image.png

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-26 11:41:49  更:2022-04-26 11:45:28 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 18:26:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码