本例主要测试 TensorFlow C++ API 中的 ops::BatchMatMul 算子。整体来说这个算子比较简单，难点在于官网没有给出使用示例，TensorFlow 自带的单元测试也覆盖得不够充分。话不多说，直接上代码。代码结构如下：
conanfile.txt
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
eigen/3.4.0
dataframe/1.20.0
opencv/3.4.17
boost/1.76.0
abseil/20210324.0
xtensor/0.23.10
[generators]
cmake
CMakeLists.txt
# CMake 3.8 is the first release that accepts CMAKE_CXX_STANDARD 17, and
# pkg_search_module(... IMPORTED_TARGET ...) below needs at least CMake 3.6.
cmake_minimum_required(VERSION 3.8)
project(test_math_ops)

# Let pkg-config also search /usr/local/lib/pkgconfig (Arrow/Parquet .pc files).
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_STANDARD 17)
# Build with debug info. add_compile_options is the command for plain compiler
# flags; add_definitions is intended for -D preprocessor definitions.
add_compile_options(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})
set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)
include_directories(${INCLUDE_DIRS})

# Every *.cpp in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Shared helper sources from the project include tree, built once into a library.
file(GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
  file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
  string(REPLACE ".cpp" "" test_name ${filename})
  add_executable(${test_name} ${test_file})
  target_link_libraries(${test_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
tf_math2_test.cpp
#include <string>
#include <vector>
#include <glog/logging.h>
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
using namespace tensorflow;
int main(int argc, char** argv) {
  // Configure glog before InitGoogleLogging() so the settings apply from the start.
  FLAGS_log_dir = "./";          // log files go to the current working directory
  FLAGS_alsologtostderr = true;  // also echo every log line to stderr
  // Severity levels: INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3; 0 keeps all of them.
  FLAGS_minloglevel = 0;

  // Project helper from death_handler.h, alive for the whole of main();
  // presumably installs a crash/backtrace handler — confirm in that header.
  Debug::DeathHandler death_handler;

  // NOTE(review): InitGoogleLogging() takes a program name (normally argv[0]),
  // not a file path; this string only becomes the prefix of the log file names.
  google::InitGoogleLogging("./logs.log");
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
TEST(TfArthimaticTests, BatchMatMul) {
  // BatchMatMul test: each slice of `left` along the leading (batch) dimension
  // is matrix-multiplied with the matching slice of `right`.
  // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
  //   left   [2, 1, 2]
  //   right  [2, 2, 3]
  //   result [2, 1, 3]
  Scope root = Scope::NewRootScope();

  /**
   * @brief Left param, shape [2, 1, 2]
   * {{{1, 2}},
   *  {{3, 4}}}
   */
  auto left = test::AsTensor<int>({1, 2, 3, 4}, {2, 1, 2});

  /**
   * @brief Right param, shape [2, 2, 3]
   * {{{1, 2, 3}, {4, 1, 2}},
   *  {{3, 4, 5}, {6, 7, 8}}}
   */
  auto right = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 3});

  /**
   * @brief Expected result, shape [2, 1, 3]
   * {{{ 9,  4,  7}},
   *  {{33, 40, 47}}}
   */
  auto batch_op = ops::BatchMatMul(root, left, right);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Run() reports failures through its Status; fail the test here instead of
  // indexing an empty `outputs` vector below.
  TF_ASSERT_OK(session.Run({batch_op.output}, &outputs));
  ASSERT_EQ(outputs.size(), 1u);

  test::PrintTensorValue<int>(std::cout, outputs[0]);
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({9, 4, 7, 33, 40, 47}, {2, 1, 3}));
}
TEST(TfArthimaticTests, BatchMatMulAdjXY) {
  // BatchMatMul test with adj_x / adj_y: each batch slice of both inputs is
  // adjointed (conjugate-transposed; a plain transpose for int) before the
  // per-batch matrix multiplication.
  // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
  //   left   [2, 2, 1] --adj_x--> [2, 1, 2]
  //   right  [2, 3, 2] --adj_y--> [2, 2, 3]
  //   result [2, 1, 3]
  Scope root = Scope::NewRootScope();

  /**
   * @brief Left param, shape [2, 2, 1]
   * {{{1},
   *   {2}},
   *  {{3},
   *   {4}}}
   */
  auto left = test::AsTensor<int>({1, 2, 3, 4}, {2, 2, 1});

  /**
   * @brief Right param, shape [2, 3, 2]
   * {{{1, 2},
   *   {3, 4},
   *   {1, 2}},
   *  {{3, 4},
   *   {5, 6},
   *   {7, 8}}}
   */
  auto right = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 3, 2});

  /**
   * @brief Expected result, shape [2, 1, 3]
   * {{{ 5, 11,  5}},
   *  {{25, 39, 53}}}
   */
  auto attrs = ops::BatchMatMul::AdjX(true).AdjY(true);
  auto batch_op = ops::BatchMatMul(root, left, right, attrs);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Run() reports failures through its Status; fail the test here instead of
  // indexing an empty `outputs` vector below.
  TF_ASSERT_OK(session.Run({batch_op.output}, &outputs));
  ASSERT_EQ(outputs.size(), 1u);

  test::PrintTensorValue<int>(std::cout, outputs[0]);
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({5, 11, 5, 25, 39, 53}, {2, 1, 3}));
}
程序输出如下，表明 BatchMatMul 算子的两个测试用例（BatchMatMul 与 BatchMatMulAdjXY）均已通过。
|