IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> C++知识库 -> C++ TensorflowLite模型验证 -> 正文阅读

[C++知识库]C++ TensorflowLite模型验证

故事是这样的:

有一个手掌检测(palm detection)的tflite模型,需要在开发板上跑起来。手机版本已经成熟,现在要移植到开发板上,并验证同一个tflite模型文件在板子上的运行结果与手机上一致。

前提:为了便于多次重复测试,Android端始终使用同一帧数据(从一段录制的mp4中固定取一帧图像)。测试代码如下图所示:

(此处原为测试代码截图,图片已缺失)

下面是测试过程:

  1. 记录下Android版API在运行推理前的图片数据文件(数据已经过归一化处理,因此都是 -1~1 之间的 float 数据)
    1. 这一步起初卡在把float数据写成二进制文件:C++端读出来的数值不对(常见原因是字节序不一致,例如Java的DataOutputStream按大端序写入,而开发板上通常按小端序读取)。
    2. 于是换了个方案:把每个float格式化成文本,逐行存储。
      /**
       * Dumps the normalized model input to Downloads/tfimg as text: one value per
       * line, formatted with "%.4f" and terminated by "\r\n", so the board-side
       * reader can load it line by line.
       *
       * Locale.ROOT is forced so the decimal separator is always '.'. With the
       * default locale (e.g. de_DE), String.format would emit "0,1234", which a
       * C/C++ atof/strtof reader parses as 0. Note that 4 decimals is lossy.
       * The writer is closed via try-with-resources even when writing fails.
       *
       * @param pfImageData normalized image data fed to the model
       */
      private void saveFile(float[] pfImageData) {
          try {
              File file = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS).getAbsolutePath() + "/tfimg");

              StringBuilder sb = new StringBuilder();
              for (float val : pfImageData) {
                  // Keep 4 decimal places; can be changed to another precision.
                  sb.append(String.format(java.util.Locale.ROOT, "%.4f", val));
                  sb.append("\r\n");
              }

              // try-with-resources closes the writer on both success and failure.
              try (FileWriter out = new FileWriter(file)) {
                  out.write(sb.toString());
              }
          } catch (Exception e) {
              e.printStackTrace();
              Log.e("Melon", "存储文件异常," + e.getMessage());
          }
      }

  2. 把这个文件拷贝到板子上,作为TFLite模型的输入

    1. 测试代码如下,重点是 RunInference() 和 read_file():
      /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      
      Licensed under the Apache License, Version 2.0 (the "License");
      you may not use this file except in compliance with the License.
      You may obtain a copy of the License at
      
          http://www.apache.org/licenses/LICENSE-2.0
      
      Unless required by applicable law or agreed to in writing, software
      distributed under the License is distributed on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      See the License for the specific language governing permissions and
      limitations under the License.
      ==============================================================================*/
      
      #include "tensorflow/lite/examples/label_image/label_image.h"
      
      #include <fcntl.h>     // NOLINT(build/include_order)
      #include <getopt.h>    // NOLINT(build/include_order)
      #include <sys/time.h>  // NOLINT(build/include_order)
      #include <sys/types.h> // NOLINT(build/include_order)
      #include <sys/uio.h>   // NOLINT(build/include_order)
      #include <unistd.h>    // NOLINT(build/include_order)
      
      #include <algorithm>
      #include <cstdarg>
      #include <cstdint>
      #include <cstdio>
      #include <cstdlib>
      #include <fstream>
      #include <iomanip>
      #include <iostream>
      #include <map>
      #include <memory>
      #include <sstream>
      #include <string>
      #include <unordered_set>
      #include <vector>
      
      #include "absl/memory/memory.h"
      #include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
      #include "tensorflow/lite/examples/label_image/get_top_n.h"
      #include "tensorflow/lite/examples/label_image/log.h"
      #include "tensorflow/lite/kernels/register.h"
      #include "tensorflow/lite/optional_debug_tools.h"
      #include "tensorflow/lite/profiling/profiler.h"
      #include "tensorflow/lite/string_util.h"
      #include "tensorflow/lite/tools/command_line_flags.h"
      #include "tensorflow/lite/tools/delegates/delegate_provider.h"
      
      namespace tflite
      {
        namespace label_image
        {
      
          // Converts a timeval to microseconds as a double. tv_sec is widened to
          // double before scaling so the multiplication cannot overflow when
          // time_t/long is 32-bit (common on 32-bit ARM boards).
          double get_us(struct timeval t)
          {
            return static_cast<double>(t.tv_sec) * 1000000.0 + t.tv_usec;
          }
      
          // Shorter local names for the TFLite delegate types used in this file.
          using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
          using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
      
          class DelegateProviders
          {
          public:
            DelegateProviders() : delegate_list_util_(&params_)
            {
              delegate_list_util_.AddAllDelegateParams();
            }
      
            // Initialize delegate-related parameters from parsing command line arguments,
            // and remove the matching arguments from (*argc, argv). Returns true if all
            // recognized arg values are parsed correctly.
            bool InitFromCmdlineArgs(int *argc, const char **argv)
            {
              std::vector<tflite::Flag> flags;
              // delegate_list_util_.AppendCmdlineFlags(&flags);
      
              const bool parse_result = Flags::Parse(argc, argv, flags);
              if (!parse_result)
              {
                std::string usage = Flags::Usage(argv[0], flags);
                LOG(ERROR) << usage;
              }
              return parse_result;
            }
      
            // According to passed-in settings `s`, this function sets corresponding
            // parameters that are defined by various delegate execution providers. See
            // lite/tools/delegates/README.md for the full list of parameters defined.
            void MergeSettingsIntoParams(const Settings &s)
            {
              // Parse settings related to GPU delegate.
              // Note that GPU delegate does support OpenCL. 'gl_backend' was introduced
              // when the GPU delegate only supports OpenGL. Therefore, we consider
              // setting 'gl_backend' to true means using the GPU delegate.
              if (s.gl_backend)
              {
                if (!params_.HasParam("use_gpu"))
                {
                  LOG(WARN) << "GPU deleate execution provider isn't linked or GPU "
                               "delegate isn't supported on the platform!";
                }
                else
                {
                  params_.Set<bool>("use_gpu", true);
                  // The parameter "gpu_inference_for_sustained_speed" isn't available for
                  // iOS devices.
                  if (params_.HasParam("gpu_inference_for_sustained_speed"))
                  {
                    params_.Set<bool>("gpu_inference_for_sustained_speed", true);
                  }
                  params_.Set<bool>("gpu_precision_loss_allowed", s.allow_fp16);
                }
              }
      
              // Parse settings related to NNAPI delegate.
              if (s.accel)
              {
                if (!params_.HasParam("use_nnapi"))
                {
                  LOG(WARN) << "NNAPI deleate execution provider isn't linked or NNAPI "
                               "delegate isn't supported on the platform!";
                }
                else
                {
                  params_.Set<bool>("use_nnapi", true);
                  params_.Set<bool>("nnapi_allow_fp16", s.allow_fp16);
                }
              }
      
              // Parse settings related to Hexagon delegate.
              if (s.hexagon_delegate)
              {
                if (!params_.HasParam("use_hexagon"))
                {
                  LOG(WARN) << "Hexagon deleate execution provider isn't linked or "
                               "Hexagon delegate isn't supported on the platform!";
                }
                else
                {
                  params_.Set<bool>("use_hexagon", true);
                  params_.Set<bool>("hexagon_profiling", s.profiling);
                }
              }
      
              // Parse settings related to XNNPACK delegate.
              if (s.xnnpack_delegate)
              {
                if (!params_.HasParam("use_xnnpack"))
                {
                  LOG(WARN) << "XNNPACK deleate execution provider isn't linked or "
                               "XNNPACK delegate isn't supported on the platform!";
                }
                else
                {
                  params_.Set<bool>("use_xnnpack", true);
                  params_.Set<bool>("num_threads", s.number_of_threads);
                }
              }
            }
      
            // Create a list of TfLite delegates based on what have been initialized (i.e.
            // 'params_').
            std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates()
                const
            {
              return delegate_list_util_.CreateAllRankedDelegates();
            }
      
          private:
            // Contain delegate-related parameters that are initialized from command-line
            // flags.
            tflite::tools::ToolParams params_;
      
            // A helper to create TfLite delegates.
            ProvidedDelegateList delegate_list_util_;
          };
      
          // Earlier attempt (disabled, kept for reference): read the input file as
          // raw bytes. Float data dumped as binary on the Android side did not read
          // back correctly, so the text-based read_file() below is used instead.
      
          // std::vector<uint8_t> read_file(const std::string &input_bmp_name)
          // {
          //   int begin, end;
      
          //   std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
          //   if (!file)
          //   {
          //     LOG(FATAL) << "input file " << input_bmp_name << " not found";
          //     exit(-1);
          //   }
      
          //   begin = file.tellg();
          //   file.seekg(0, std::ios::end);
          //   end = file.tellg();
          //   size_t len = end - begin;
      
          //   LOG(INFO) << "len: " << len;
          //   std::vector<uint8_t> img_bytes(len);
      
          //   file.seekg(0, std::ios::beg);
          //   file.read(reinterpret_cast<char *>(img_bytes.data()), len);
      
          //   return img_bytes;
          // }
      
          /**
           * Reads a text file containing one decimal float per line and returns the
           * values in file order. Lines may end in "\n" or "\r\n".
           *
           * Blank (whitespace-only) lines are skipped. A non-blank line that is not
           * exactly one float (surrounding whitespace allowed) is fatal: silently
           * substituting 0, as atof would, corrupts the model input being verified.
           * Exits the process with -1 if the file cannot be opened or a line cannot
           * be parsed.
           *
           * @param input_bmp_name path of the text file
           * @return the parsed values
           */
          std::vector<float> read_file(const std::string &input_bmp_name)
          {
            std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
            if (!file)
            {
              LOG(FATAL) << "input file " << input_bmp_name << " not found";
              exit(-1);
            }

            // File size in bytes, logged for diagnostics only.
            file.seekg(0, std::ios::end);
            const std::streamoff len = file.tellg();
            file.seekg(0, std::ios::beg);
            LOG(INFO) << "len: " << len;

            static const char kWhitespace[] = " \t\r\n\f\v";
            std::vector<float> img_bytes;
            std::string line;
            size_t line_no = 0;
            while (std::getline(file, line))
            {
              ++line_no;
              const char *begin = line.c_str();
              char *end = nullptr;
              const float value = std::strtof(begin, &end);
              const size_t consumed = static_cast<size_t>(end - begin);

              if (consumed == 0)
              {
                // Nothing parsed: tolerate blank lines (e.g. a lone "\r"), reject text.
                if (line.find_first_not_of(kWhitespace) == std::string::npos)
                  continue;
                LOG(FATAL) << "invalid float at line " << line_no << " of "
                           << input_bmp_name << ": \"" << line << "\"";
                exit(-1);
              }
              // Only whitespace (such as the '\r' of "\r\n") may follow the number.
              if (line.find_first_not_of(kWhitespace, consumed) != std::string::npos)
              {
                LOG(FATAL) << "trailing characters at line " << line_no << " of "
                           << input_bmp_name << ": \"" << line << "\"";
                exit(-1);
              }
              img_bytes.push_back(value);
            }

            LOG(INFO) << "文件读取完成:" << input_bmp_name;
            return img_bytes;
          }
      
          /**
           * Runs the model at settings->model_name on the float values read (one per
           * line) from settings->input_bmp_name. It then prints every value of output
           * tensor #1 to stdout so the result can be diffed against the Android run.
           *
           * Any failure is logged and the process exits with -1. Failures include: no
           * model name; model or interpreter creation; too few inputs/outputs; tensor
           * allocation; a non-float32 or too-small input tensor; a non-positive loop
           * count; invocation; a non-float32 output tensor.
           */
          void RunInference(Settings *settings)
          {
            // std::string::c_str() never returns null, so test for an empty name.
            if (settings->model_name.empty())
            {
              LOG(ERROR) << "no model file name";
              exit(-1);
            }

            std::unique_ptr<tflite::FlatBufferModel> model =
                tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
            if (!model)
            {
              LOG(ERROR) << "Failed to mmap model " << settings->model_name;
              exit(-1);
            }
            settings->model = model.get();
            LOG(INFO) << "Loaded model " << settings->model_name;
            model->error_reporter();
            LOG(INFO) << "resolved reporter";

            tflite::ops::builtin::BuiltinOpResolver resolver;
            std::unique_ptr<tflite::Interpreter> interpreter;
            tflite::InterpreterBuilder(*model, resolver)(&interpreter); // build the interpreter
            if (!interpreter)
            {
              LOG(ERROR) << "Failed to construct interpreter";
              exit(-1);
            }

            const std::vector<int> inputs = interpreter->inputs();
            const std::vector<int> outputs = interpreter->outputs();

            // Position (within outputs()) of the tensor that is printed. The original
            // dump named it "classificators" (the palm-detection scores); the
            // assumption that this is output #1 is specific to that model.
            constexpr int kScoreOutput = 1;
            if (inputs.empty() || outputs.size() <= static_cast<size_t>(kScoreOutput))
            {
              LOG(ERROR) << "Model must have at least 1 input and " << (kScoreOutput + 1)
                         << " outputs, got " << inputs.size() << " and " << outputs.size();
              exit(-1);
            }

            interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);

            if (settings->verbose)
            {
              LOG(INFO) << "tensors size: " << interpreter->tensors_size();
              LOG(INFO) << "nodes size: " << interpreter->nodes_size();
              LOG(INFO) << "inputs: " << interpreter->inputs().size();
              LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);

              int t_size = interpreter->tensors_size();
              for (int i = 0; i < t_size; i++)
              {
                if (interpreter->tensor(i)->name)
                  LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                            << interpreter->tensor(i)->bytes << ", "
                            << interpreter->tensor(i)->type << ", "
                            << interpreter->tensor(i)->params.scale << ", "
                            << interpreter->tensor(i)->params.zero_point;
              }
            }

            if (settings->number_of_threads != -1)
            {
              interpreter->SetNumThreads(settings->number_of_threads);
            }

            std::vector<float> file_bytes = read_file(settings->input_bmp_name);
            // Log the first values for a side-by-side comparison with the Android
            // input; bounded so a short file is never read past its end.
            const size_t preview_count = std::min<size_t>(100, file_bytes.size());
            for (size_t i = 0; i < preview_count; i++)
            {
              LOG(INFO) << i << ": " << file_bytes[i];
            }

            // inputs()[0] is the index of the model's first input tensor within the
            // interpreter's tensor list.
            const int input = inputs[0];
            LOG(INFO) << "input: " << input;

            LOG(INFO) << "number of inputs: " << inputs.size();
            LOG(INFO) << "input index: " << inputs[0];
            LOG(INFO) << "number of outputs: " << outputs.size();
            LOG(INFO) << "outputs index1: " << outputs[0] << ",outputs index2: " << outputs[1];

            if (interpreter->AllocateTensors() != kTfLiteOk)
            { // allocate memory for all tensors
              LOG(ERROR) << "Failed to allocate tensors!";
              exit(-1);
            }

            if (settings->verbose)
              PrintInterpreterState(interpreter.get());

            // The file holds floats, so only a float32 input tensor can be filled.
            // typed_input_tensor<float> returns nullptr on a type mismatch.
            settings->input_type = interpreter->tensor(input)->type;
            float *input_data = interpreter->typed_input_tensor<float>(0);
            if (settings->input_type != kTfLiteFloat32 || input_data == nullptr)
            {
              LOG(ERROR) << "Input tensor " << input << " is not float32 (type "
                         << settings->input_type << ")";
              exit(-1);
            }

            // Copy the file values into the input tensor without writing past its end.
            const size_t input_capacity = interpreter->tensor(input)->bytes / sizeof(float);
            LOG(INFO) << "file_bytes size: " << file_bytes.size();
            if (file_bytes.size() > input_capacity)
            {
              LOG(ERROR) << "Input file has " << file_bytes.size()
                         << " values but the input tensor holds only " << input_capacity;
              exit(-1);
            }
            if (file_bytes.size() < input_capacity)
            {
              LOG(WARN) << "Input file has " << file_bytes.size()
                        << " values, input tensor expects " << input_capacity
                        << "; zero-filling the rest";
            }
            std::copy(file_bytes.begin(), file_bytes.end(), input_data);
            std::fill(input_data + file_bytes.size(), input_data + input_capacity, 0.0f);

            // At least one run is required for a meaningful output and average.
            if (settings->loop_count < 1)
            {
              LOG(ERROR) << "loop count must be at least 1, got " << settings->loop_count;
              exit(-1);
            }

            struct timeval start_time, stop_time;
            gettimeofday(&start_time, nullptr);
            for (int i = 0; i < settings->loop_count; i++)
            { // run the model
              if (interpreter->Invoke() != kTfLiteOk)
              {
                LOG(ERROR) << "Failed to invoke tflite!";
                exit(-1);
              }
            }
            gettimeofday(&stop_time, nullptr);
            LOG(INFO) << "invoked";
            // 1000.0 keeps the divisor in floating point (no int overflow).
            LOG(INFO) << "average time: "
                      << (get_us(stop_time) - get_us(start_time)) /
                             (settings->loop_count * 1000.0)
                      << " ms";

            const int output = outputs[kScoreOutput];
            LOG(INFO) << "output: " << output;
            LOG(INFO) << "interpreter->tensors_size: " << interpreter->tensors_size();

            const TfLiteTensor *tensor = interpreter->tensor(output);
            const TfLiteIntArray *output_dims = tensor->dims;
            // Assume output dims look like (1, 1, ..., size); a 0-d (scalar) tensor
            // has no last dimension and holds a single value.
            const int output_size =
                output_dims->size > 0 ? output_dims->data[output_dims->size - 1] : 1;
            LOG(INFO) << "索引为" << output << "的输出张量的-"
                      << "output_size: " << output_size;

            for (int i = 0; i < output_dims->size; i++)
            {
              LOG(INFO) << "元数据有:" << output_dims->data[i];
            }

            // typed_output_tensor<float> returns nullptr if the output is not float32.
            const float *prediction = interpreter->typed_output_tensor<float>(kScoreOutput);
            if (prediction == nullptr)
            {
              LOG(ERROR) << "Output tensor " << output << " is not float32";
              exit(-1);
            }

            // Print every value of the output tensor, one per line as "value \n",
            // followed by a blank line. The element count comes from the tensor
            // itself, so it is never read past its end; a [1, 896, 1] score tensor
            // prints 896 lines.
            const size_t prediction_count = tensor->bytes / sizeof(float);
            for (size_t i = 0; i < prediction_count; i++)
            {
              std::cout << prediction[i] << ' ' << '\n';
            }
            std::cout << std::endl;
          }
      
          /**
           * Program entry: parses the command-line options into a Settings struct,
           * copies the delegate-related settings into the delegate parameters, and
           * runs inference.
           *
           * Every option takes an argument (see long_options). Boolean switches are
           * parsed as integers, so "0" disables and non-zero enables. Unknown or
           * malformed options terminate the process with -1.
           *
           * @return EXIT_FAILURE if delegate flag parsing fails, otherwise 0
           */
          int Main(int argc, char **argv)
          {
            DelegateProviders delegate_providers;
            bool parse_result = delegate_providers.InitFromCmdlineArgs(
                &argc, const_cast<const char **>(argv));
            if (!parse_result)
            {
              return EXIT_FAILURE;
            }

            Settings s;

            static struct option long_options[] = {
                {"accelerated", required_argument, nullptr, 'a'},
                {"allow_fp16", required_argument, nullptr, 'f'},
                {"count", required_argument, nullptr, 'c'},
                {"verbose", required_argument, nullptr, 'v'},
                {"image", required_argument, nullptr, 'i'},
                {"labels", required_argument, nullptr, 'l'},
                {"tflite_model", required_argument, nullptr, 'm'},
                {"profiling", required_argument, nullptr, 'p'},
                {"threads", required_argument, nullptr, 't'},
                {"input_mean", required_argument, nullptr, 'b'},
                {"input_std", required_argument, nullptr, 's'},
                {"num_results", required_argument, nullptr, 'r'},
                {"max_profiling_buffer_entries", required_argument, nullptr, 'e'},
                {"warmup_runs", required_argument, nullptr, 'w'},
                {"gl_backend", required_argument, nullptr, 'g'},
                {"hexagon_delegate", required_argument, nullptr, 'j'},
                {"xnnpack_delegate", required_argument, nullptr, 'x'},
                {nullptr, 0, nullptr, 0}};

            while (true)
            {
              /* getopt_long stores the option index here. */
              int option_index = 0;

              const int c = getopt_long(argc, argv,
                                        "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options,
                                        &option_index);

              /* Detect the end of the options. */
              if (c == -1)
                break;

              switch (c)
              {
              case 'a':
                s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'b':
                s.input_mean = strtod(optarg, nullptr);
                break;
              case 'c':
                s.loop_count =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'e':
                s.max_profiling_buffer_entries =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'f':
                s.allow_fp16 =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'g':
                s.gl_backend =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'i':
                s.input_bmp_name = optarg;
                break;
              case 'j':
                // Parse as an integer like the other switches. Assigning optarg (a
                // non-null char*) directly made any value, even "0", enable Hexagon.
                s.hexagon_delegate =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'l':
                s.labels_file_name = optarg;
                break;
              case 'm':
                s.model_name = optarg;
                break;
              case 'p':
                s.profiling =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'r':
                s.number_of_results =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 's':
                s.input_std = strtod(optarg, nullptr);
                break;
              case 't':
                s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
                    optarg, nullptr, 10);
                break;
              case 'v':
                s.verbose =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'w':
                s.number_of_warmup_runs =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'x':
                s.xnnpack_delegate =
                    strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
                break;
              case 'h':
              case '?':
                /* getopt_long already printed an error message. */
                exit(-1);
              default:
                exit(-1);
              }
            }

            delegate_providers.MergeSettingsIntoParams(s);
            RunInference(&s);
            return 0;
          }
      
        } // namespace label_image
      } // namespace tflite
      
      // Process entry point: all work is delegated to tflite::label_image::Main(),
      // whose return value becomes the process exit status.
      int main(int argc, char *argv[])
      {
        const int exit_code = tflite::label_image::Main(argc, argv);
        return exit_code;
      }
      

    2. 运行指令:./ws_app --tflite_model libnewpalm_detection.tflite --image tfimg
  3. 对比推理前的输入,确认两端一致

    1. Android端
    2. 开发板上
  4. 对比推理后的输出,确认两端一致
    1. Android端
    2. 开发板端

  C++知识库 最新文章
【C++】友元、嵌套类、异常、RTTI、类型转换
通讯录的思路与实现(C语言)
C++PrimerPlus 第七章 函数-C++的编程模块(
Problem C: 算法9-9~9-12:平衡二叉树的基本
MSVC C++ UTF-8编程
C++进阶 多态原理
简单string类c++实现
我的年度总结
【C语言】以深厚地基筑伟岸高楼-基础篇(六
c语言常见错误合集
上一篇文章      下一篇文章      查看所有文章
加:2021-08-26 11:56:22  更:2021-08-26 11:59:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/23 16:49:31-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码