我是靠谱客的博主 神勇水杯,最近开发中收集的这篇文章主要介绍TensorFlow Object Detection API 模型转换为 MNN (1),觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
文章目录
- 0. 前言
- 1. TFLite 生成流程
- 1.1 运行 `export_tflite_ssd_graph.py` 脚本
- 1.2 `tflite_convert` 工具使用
- 2. TensorFlow Object Detection API 模型转换脚本
- 2.1 过程
- 2.2 测试结果
- 附录
- TFLite 的输入与输出
0. 前言
-
最近在做目标检测模型的端侧部署,想把目前开源的一系列模型部署到 ARM CPU 下进行测试。
- 常见的目标检测端侧模型主要都是 yolo 系列和 ssd 系列。
- ssd 系列模型中,TensorFlow Object Detection API 应该是最好的项目了。
-
本文主要目标就是将 TensorFlow Object Detection API 中的模型部署到 ARM CPU 下。
- ARM CPU 部署主要使用了 NCNN 和 MNN。也许TFLite也行,但暂时先不考虑。
- NCNN 对 TensorFlow 支持非常不好,TF转ONNX转NCNN这条路走起来非常困难。
- MNN 支持的 Ops 比较多,这也成了最终的选择。
-
先说自己尝试 TensorFlow Object Detection API 到 MNN 的结论:
- 通过 pb 转 MNN 并不成功,原因不明。不成功有两种表现形式:
- 第一,通过
tools/script/fastTestTf.py
对 ssd 的 pb 文件进行测试,无法通过。 - 第二,pb 转 MNN 能够成功,但在后续执行
resizeSession
操作时会失败。
- 第一,通过
- 必须将 ssd 模型转换为 tflite,再转换为 MNN。
- 通过 pb 转 MNN 并不成功,原因不明。不成功有两种表现形式:
-
由于 TensorFlow Object Detection API 库分为两个部分,tf1.x 和 tf2.x,所以文章也分为两篇。本文主要介绍 tf1.x 相关模型的转换。
1. TFLite 生成流程
- TFLite 文件的生成分为两步
- 第一步,通过
export_tflite_ssd_graph.py
脚本将训练好的模型转换为 pb 格式 - 第二步,通过
tflite_convert
命令将 pb 转换为 tflite。
- 第一步,通过
- 官方资料
- 官方文档 - Running on mobile with TensorFlow Lite:介绍了详细流程
- export_tflite_ssd_graph.py 中的注释:介绍了一些转换中的细节
1.1 运行 export_tflite_ssd_graph.py
脚本
-
脚本输入:
- 模型配置文件。
- 模型训练权重文件,即
model.ckpt
。- 请注意,这个权重分为四个文件,
checkpoint/model.ckpt.meta/model.ckpt.index/model.ckpt.data
,分别记录了目录下模型文件列表、网络结构、参数名、参数数值。
- 请注意,这个权重分为四个文件,
-
脚本输出:在指定路径输出
tflite_graph.pbtxt
和tflite_graph.pb
两个文件。 -
脚本使用形式
- 注意,是在 research 目录下运行,而不是 object_detection 目录下运行
python object_detection/export_tflite_ssd_graph.py
--pipeline_config_path path/to/ssd_mobilenet.config
--trained_checkpoint_prefix path/to/model.ckpt
--output_directory path/to/exported_model_directory
1.2 tflite_convert
工具使用
- 需要安装
tflite_convert
工具。- 官方文档中说要源码安装,但我通过 conda 安装 tensorflow-gpu 后,这个命令也已经存在了。
- 命令的使用
- 注意,
inference_type
必须看情况使用
- 注意,
tflite_convert
--enable_v1_converter
--graph_def_file=tflite_graph.pb
--output_file=detect.tflite
--input_shapes=1,300,300,3
--input_arrays=normalized_input_image_tensor
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'
--inference_type=QUANTIZED_UINT8
--mean_values=128
--std_dev_values=128
--change_concat_input_ranges=false
--allow_custom_ops
2. TensorFlow Object Detection API 模型转换脚本
2.1 过程
-
脚本源码在这里
-
准备工作:
- 安装好 tf1.x 版的 TensorFlow Object Detection API
- 如果要转.mnn,必须先编译好 MNN
-
基本流程
- 第一步:从 TensorFlow Object Detection API Model Zoo 中下载权重文件,并解压
- 第二步:前一章介绍的两个工具,将 mode.ckpt 转换为 detect.tflite
- 第三步:通过 MNN 模型转换工具,将生成的 detect.tflite 转换为 .mnn 格式
-
一些注意事项:
- 如果要在转tflite的时候可以指定
--inference_type=QUANTIZED_UINT8
参数,那么转换前的 tensorflow 模型必须也是量化过的 - 注意模型输入尺寸(可以从
pipeline.config
中找到模型的尺寸)
- 如果要在转tflite的时候可以指定
2.2 测试结果
- 在 TensorFlow Object Detection API 1.x 版的所有 ssd 模型中:
model.ckpt
->tflite.pb
->detect.tflie
,这条路全部走得通(废话,这是官方支持的).tflite
->.mnn
的流程,除了ssdlite_mobilenet_edgetpu_coco_quant
外,其他都走得通。
- 使用
MNN/tools/scripts/fastTestTflite.py
验证 .tflite 模型- Model Zoo 中的 Mobile models 都通过,其他模型都不通过
- 错误类型类似于
TESTERROR TFLite_Detection_PostProcess value error : absMaxV:11.129623 - DiffMax 10.956316
- 感觉就是 argmax 相关,在实现 stdcnet 的时候也碰到一样的问题,如果去掉 argmax 就没问题。这个打算跟MNN团队反馈一下。
- 要解决估计就是输出的时候必须指定
add_postprocessing_op=false
,还没有具体测试过
附录
TFLite 的输入与输出
-
在
export_tflite_ssd_graph.py
脚本的注释中介绍了转换后的 ssd tflite 模型的输入与输出形式 -
输入:
- 如果是INT8量化后的模型,输入原始 uint8 形式的 RGB 图片即可。
- 如果是未量化的模型,需要将原始uint8形式、数据范围
[0, 255]
的输入转换到 float[-1, 1]
的形式,基本操作就是img / 128. - 1
- 从 TFLite 的角度,图片不需要resize。 当然,如果转换成MNN,可能还是要固定一下输入的尺寸。
-
输出:
-
如果指定了
add_postprocessing_op=true
(默认情况),则输出四个数据:detection_boxes
:数据类型 float32,shape为[1, num_boxes, 4]
detection_classes
:数据类型 float32,shape为[1, num_boxes]
detection_scores
:数据类型 float32,shape为[1, num_boxes]
num_boxes
:数据类型 float32,shape为[1]
-
如果
add_postprocessing_op=false
,那么有两个输出raw_outputs/box_encodings
:数据类型 float32,shape为[1, num_anchors, 4]
raw_outputs/class_predictions
:数据类型 float32,shape为[1, num_anchors, num_classes]
- 这部分没用过,没注意具体形式。
-
最后
以上就是神勇水杯为你收集整理的TensorFlow Object Detection API 模型转换为 MNN (1)的全部内容,希望文章能够帮你解决TensorFlow Object Detection API 模型转换为 MNN (1)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复