我是靠谱客的博主 缥缈大地,最近开发中收集的这篇文章主要介绍pytorch模型转tflite【以EfficientNet-BTS为例】步骤配置环境Pytorch转onnxonnx模型测试onnx模型简化onnx转tensorflowtensorflow转tflite,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
步骤
使用pytorch转tflite需要经过:pytorch -> onnx -> tensorflow -> tflite
配置环境
# ONNX-TensorFlow: 1.8.0 [pip install onnx-tf==1.8.0]
# ONNX: 1.8.0 [pip install onnx==1.8.0]
## TensorFlow: 2.4.0 [pip install tensorflow==2.4.0]
# tf-nightly: 2.9.0-dev20220223 [pip install tf-nightly]
# PyTorch: 1.8.0 [pip install torch==1.8.0 ]
环境配置上的一些问题:
- 使用Tensorflow 2.4.0 会在onnx导出pb文件时报错,参考链接。应当使用tf-nightly。issue中推荐使用tf-nightly 2.4.0,测试发现使用最新版本2.9.0也可以解决问题。
- 使用Pytorch 1.7.0 时会出现Cat等冗余op维度不匹配的问题。导出的onnx模型无法正确inference。使用Pytorch1.8可以规避这个问题
- onnx与tf的版本对应可以参考链接。
Pytorch转onnx
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from onnxsim import simplify
import onnxruntime as ort
import numpy as np
if __name__ == '__main__':
model = Model()
# Converting model to ONNX
for _ in model.modules():
_.training = False
test_arr = np.random.randn(1, 3, 480, 640).astype(np.float32)
sample_input = torch.tensor(test_arr)
# sample_input = torch.randn(1, 3, 480, 640)
input_nodes = ['input']
output_nodes = ['output']
model(sample_input)
torch.onnx.export(model, sample_input, "model.onnx", export_params=True, input_names=input_nodes,
output_names=output_nodes, opset_version=11)
- 此处注意opset_version=11,如果设置opset_version=10 / 9 会出现一些op不支持的问题,例如upsample_bilinear。
- 模型输入大小应当与原始模型输入大小一致,如果想动态适应,可以修改export中dynamic_axis参数
- Gpu应当设置为不可用,使得全部导出过程在CPU上运行。
onnx模型测试
model = onnx.load("model.onnx")
ort_session = ort.InferenceSession('model.onnx')
onnx_outputs = ort_session.run(None, {'input': test_arr})
print('Export ONNX!')
- 如果可以正常通过,证明onnx可以正确导出。
- 测试时可以和原模型输出对照一下,观察是否存在误差。
onnx模型简化
onnx_model = onnx.load("model.onnx")
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
- 模型简化使用的是onnx-simplify工具
- 模型简化可以去除一些在模型转化过程中产生的冗余Op,例如Concat / SUB
onnx转tensorflow
output = prepare(model_simp)
output.export_graph("tf_model/")
print('Export tf_model!')
- onnx转Tensorflow过程中可能会遇到一些Op无法转化的问题,例如interpolate函数,align_corners应当设置为True,然后重新导出onnx。参考链接
tensorflow转tflite
converter = tf.lite.TFLiteConverter.from_saved_model("tf_model")
tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
print('Export tf lite model!')
- 转换时候可能会存在一些问题。安装tf-nightly可以解决。
Onnx和Tflite模型可以通过Netron工具可视化查看。
最后
以上就是缥缈大地为你收集整理的pytorch模型转tflite【以EfficientNet-BTS为例】步骤配置环境Pytorch转onnxonnx模型测试onnx模型简化onnx转tensorflowtensorflow转tflite的全部内容,希望文章能够帮你解决pytorch模型转tflite【以EfficientNet-BTS为例】步骤配置环境Pytorch转onnxonnx模型测试onnx模型简化onnx转tensorflowtensorflow转tflite所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复