如果开发C++代码,链接pip安装的Tensorflow安装目录下面的so,会报如下错误:
E tensorflow/core/common_runtime/session.cc:67] Not found: No session factory registered for the given session options: {target: “” config: } Registered factories are {}.
同时会发现TensorFlow内部的算子都未注册,即使使用-Wl,–whole-archive处理也无法解决。
那么是否可以实现直接使用pip安装的tensorflow的so和头文件,实现C++接口调用推理呢?作者发现了一个方法并分享如下。
main.cpp推理代码example
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include
#include
using namespace tensorflow;
#ifdef __cplusplus
extern "C" {
#endif
// instantiated in tensorflow _pywrap_tensorflow_internal.so
extern const char* TF_Version(void);
#ifdef __cplusplus
}
#endif
int main() {
// must be called to load op register
TF_Version();
std::string model_path = "resnet_50.pb";
tensorflow::GraphDef graphdef;
tensorflow::Status status_load = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graphdef);
tensorflow::SessionOptions options;
tensorflow::Session* session;
session = tensorflow::NewSession(options);
if (session == nullptr) {
std::cout << "create new session failed" << std::endl;
return -1;
}
tensorflow::Status status;
status = session->Extend(graphdef);
if (!status.ok()) {
std::cout << "session extend graph failed" << std::endl;
return -1;
}
Tensor x(DT_FLOAT, TensorShape({1, 3, 224, 224}));
std::vector> input_tensors;
input_tensors.push_back({"input", x});
std::vector output_names = {"resnet_model/stage_1/Relu_2"};
std::vector outputs;
TF_CHECK_OK(session->Run(input_tensors, output_names, {}, &outputs));
// release session
session->Close();
delete session;
session = nullptr;
return 0;
}
这里的核心是调用了TF_Version();(可能其他函数也有类似功效) 从而成功加载so里面的符号,否则并不会加载。具体原因欢迎大家在评论区讨论。这个函数tf 2.x的pip安装包里已经提供了接口定义,而1.1x没有,需要手动定义下。
cmake文件编译选项
核心是需要包含python的so,tf的两个so
project(tf_cpp_test LANGUAGES CXX)
add_compile_options(-fPIC)
tf version >=1.15 use ABI=0
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_executable(
main
main.cpp
)
target_include_directories(
main
PUBLIC
$ENV{TF_INCLUDE_PATH}
$ENV{PYTHON_INCLUDE_PATH}
)
target_link_libraries(
main
PUBLIC
$ENV{TF_SO_FILE}
$ENV{TF_SO_PATH}/python/_pywrap_tensorflow_internal.so
$ENV{PYTHON_SO_FILE}
)
上面的TF_INCLUDE_PATH等可以通过bash脚本获取:
#!/bin/bash
TOOL_SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export TF_INCLUDE_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_compile_flags()[0].strip("-I"))')
export TF_SO_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_link_flags()[0].strip("-L"))')
export TF_SO_FILE=$(ls $TF_SO_PATH/libtensorflow_framework.* |head -1)
export PYTHON_INCLUDE_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("include"))')
export PYTHON_SO_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("stdlib"))')
export PYTHON_SO_FILE=$(find $PYTHON_SO_PATH/../ -name libpython3*.so|head -1)
mkdir ${TOOL_SCRIPT_DIR}/build
cd ${TOOL_SCRIPT_DIR}/build
cmake ..
make
上述代码测试环境:tf1.15+python3.7(基于conda虚拟环境)
Original: https://blog.csdn.net/u013701860/article/details/122241038
Author: Luchang-Li
Title: TensorFlow不重新编译源码使用C/C++ API推理
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514756/
转载文章受原作者版权保护。转载请注明原作者出处!