概述
本篇博客记录笔者最近在在线推理服务中使用 Tensorflow C++ 接口的若干心得和疑(tu)惑(cao),整个流程包括创建 session ,加载 graph ,填充 tensor ,运行 session ,等等。注意,因为 tensorflow 2.0 没有普及,考虑稳定性,本篇博客代码均基于 tensorflow 1.12 。
1. session
1.1 session & client_session
我们知道 tensorflow 所有节点都处于 graph,而 graph 则和 session 绑定,所以线上的实时预测服务在初始化时需要创建 session 并载入 graph。网上找到的很多例子都是用 session
,而官网上只提供了 client_session
的接口,两者的主要区别在于 Run()
函数的参数不一样,client_session
如下:
Status Run(
const FeedType & inputs,
const std::vector< Output > & fetch_outputs,
const std::vector< Operation > & run_outputs,
std::vector< Tensor > *outputs
) const
这个 FeedType 定义如下:
typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType;
对比 session
的 Run()
:
virtual Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata);
看出区别没有?哈希表的 key 不同,一个是自定义类 Output,一个是 string。
我们要运行会话进行预测,不妨把模型当成一个黑盒子,那么关键的步骤有两步,喂数据和取结果,而喂数据要给不同的 placeholder
喂不同的数据,取结果则需要知道从哪个 operation
取结果,所以关键是要有一个哈希表记录 placeholder 或者 operation。client_session
是用 tensorflow c++ api 的 Placeholder 对象或者 Operation 对象作为哈希表的 key ,而 session
则是用 string。
所以,训练可以用 client_session
或者 session
,而预测只能用 session
。因为如果是训练的话,可以掉用 tf c++ api 创建 session
graph
placeholder
,可以得到 Placeholder
对象再 Run()
,但是预测过程是从模型文件中建立 graph,无法得到 placeholder 对象,所以也就无法使用 client_session
了。而 session
此时可以大展身手了,只需要训练方为需要输入和输出的节点命名,预测方就可以通过名称找到对应的节点,喂数据或者取结果就都可以进行了。
值得注意的是,client_session
也是通过封装 session
来实现的。所以为什么 tensorflow 官方文档只有 client_session
而没有 session
,实在令人困惑啊。
1.2 创建 session
tensorflow::NewSession()
可用于创建 session
tensorflow::SessionOptions options;
auto session = std::unique_ptr<tensorflow::Session>(tensorflow::NewSession(options));
2. load graph
关于图,模型文件在存储图时将图给“骨肉分离”了。骨为结构,存储节点与节点之间的连接,肉为数值,存储 variable 大小,有篇博客很好地解释了 tf 的图:Tensorflow框架实现中的“三”种图。有两种方式加载图。
方法一 session->Run()
执行 restore_op,代码如下:
tensorflow::MetaGraphDef graph_def;
std::string meta_graph_path = "model.meta";
auto status = ReadBinaryProto(tensorflow::Env::Default(), meta_graph_path, &graph_def);
if (!status.ok()) {
...
}
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
...
}
auto restore_op_name = graph_def.saver_def().restore_op_name();
auto filename_tensor_name = graph_def.saver_def().filename_tensor_name();
tensorflow::Input::Initializer filename({model_dir});
status = session->Run({{filename_tensor_name, filename.tensor}}, {}, {restore_op_name}, nullptr);
if (!status.ok()) {
...
}
方法二 tensorflow::SavedModelBundle
这种方法没有深究。
用哪种方式其实是由训练方导出模型的方式决定的。可以查看 SavedModelBundle
源码,其实和方法一类似,分别调用了 ReadBinaryProto()
session->Run()
,不同之处在于其读取的 pb 文件名称是固定的,“saved_model.pb”,也就是说方法二无法自定义 pb 名称。
3. initalize tensor
初始化 tensor 有三种方式,可以参考 stackoverflow 上的一个回答 How to fill a tensor in C++
3.1 tensorflow::Input::Initializer
一种方法是使用 tensorflow::Input::Initializer
,用法如下:
tensorflow::Input::Initializer x0_index({0, 0, 1, 1}, tensorflow::TensorShape({2, 2}));
不过使用过程中出现过一个报错:
Invalid argument: Expects arg[0] to be int64 but int32 is provided
这是因为 Initalizer 用模板传参,不能正确识别类型(我需要 int64,但是模板识别为 int32),导致在使用 tensor 时报错.
官方 api 可以看到
tensorflow::Input::Initializer::Initializer(
const std::initializer_list< T > & v,
const TensorShape & shape
)
解决方法是,每个数据加上 LL 后缀
3.2 x.tensor<>()() = XX
另一种方法是逐个赋值:
tensorflow::Tensor x0(tensorflow::DT_FLOAT, tensorflow::TensorShape({2,2}));
x0.tensor<float, 2>()(0,0) = 1;
x0.tensor<float, 2>()(0,1) = 2;
x0.tensor<float, 2>()(1,0) = 2;
x0.tensor<float, 2>()(1,1) = 3;
逐个赋值看似更麻烦,但是在遍历一个容器(例如 vector )填充 tensor 时更方便,这是因为 sd::initializer_list
只支持初始化列表的初始化方式,无法用其它容器的迭代器初始化,而且初始化后也不能 push,所以也不能写个 for 循环喂数据给它,总之很不好用。tensorflow 用sd::initializer_list
这个容器应该是出于性能考虑,但是很不方便,不过不用担心,还有第三种方法。
中间还遇到一个数据类型问题,就是编译如下的代码:
tensorflow::Tensor x_index(tensorflow::DT_INT64, tensorflow::TensorShape({100, 100, 2}));
for (int64_t i = 0; i < 100; ++i) {
for (int64_t j = 0; j < 100; ++j) {
x_index.tensor<int64_t, 3>()(i, j, 0) = i;
x_index.tensor<int64_t, 3>()(i, j, 1) = j;
}
}
报错:
tensorflow/core/framework/types.h:357:3: error: static assertion failed: Specified Data Type not supported
查看源文件,发现问题所在,tensorflow 会将 基本数据类型(例如 float) 转成 tensorflow 自定义数据类型(例如 DT_FLOAT),而 int64_t 不在其转换范围内。解决方法是,要么使用 long long int ,要么使用 tensorflow 自定义的 int64
,如下:
x_index.tensor<long long int, 3>()(i, j, 0) = i;
x_index.tensor<tensorflow::int64, 3>()(i, j, 0) = i;
3.3 flat()
第三种方式 flat()
std::copy_n(x_indices.begin(), x_indices.size(), x_index.flat<tensorflow::int64>().data());
个人觉得这种方式最优,先把所有元素按顺序放进随便一个只要不是 sd::initializer_list
的容器中(按顺序是指从低维到高维每个维度依次遍历,例如二维矩阵按照从左到右从上到下的顺序),然后再灌进 tensor。flat()
这个函数很形象,把高维的 tensor 拍扁了,拍成一维数组,这样 fill tensor 就方便多了,不用像 3.2 那样再考虑某个元素放在第几维的位置。
4. sparse tensor
为什么上面代码的 tensor 变量名都是 x_index 呢,其实都是为了构建稀疏张量做准备的。稀疏张量只有很少的元素非零,所以只需标示出非零元素的位置和值。一个 sparse tensor 由 x_index, x_value, x_shape 三部分构成,x_index 表示 tensor 非零元素的位置,x_value 表示非零元素的值,x_shape 表示 tensor 的形状。构建一个 sparse tensor 代码如下:
tensorflow::Input::Initializer x0_index({0LL, 0LL, 1LL, 1LL}, tensorflow::TensorShape({2, 2}));
tensorflow::Input::Initializer x0_value({1.0f, 2.0f}, tensorflow::TensorShape({2}));
tensorflow::TensorShape x0_shape({2, 2});
tensorflow::sparse::SparseTensor x0_sparse_tensor(x0_index.tensor, x0_value.tensor, x0_shape);
构建好了 SparseTensor 后,就可以使用了,怎么用呢?嗯,没法用。是的,Run()
没有 SparseTensor 作为输入参数的重载函数。查看 tensorflow/core/public/session.h:
virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) = 0;
Run()
只支持类型为 Tensor 的输入,所以 tensorflow c++ 提供 SparseTensor 接口是在逗你玩?只能看看不能用?
如果要使用 SparseTensor ,有一种解决方法,就是将 feed_dict 里输入张量 从一个 placeholder
改为 x_index, x_value, x_shape 三个 placeholder
,然后再以此生成 SparseTensor,python 代码如下:
x0_index = tf.placeholder(tf.int64, name = 'x0_index')
x0_value = tf.placeholder(tf.float32, name = 'x0_value')
x0_shape = tf.placeholder(tf.int64, name = 'x0_shape')
x0 = tf.SparseTensor(x0_index, x0_value, x0_shape)
这样 C++ 端的代码只需要喂这三个非稀疏 Tensor 就行了。
最后
以上就是活泼镜子为你收集整理的Tensorflow C++ api 笔记1. session2. load graph3. initalize tensor4. sparse tensor的全部内容,希望文章能够帮你解决Tensorflow C++ api 笔记1. session2. load graph3. initalize tensor4. sparse tensor所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复