Overview
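This post walks through converting a PyTorch model to TensorFlow Lite in four steps: PyTorch -> ONNX, ONNX -> TensorFlow SavedModel, SavedModel -> TFLite, and finally running inference with the TFLite interpreter. Known issues are listed at the end.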
1. PyTorch -> ONNX
# model, img (a dummy input tensor), y (a sample model output) and opt come from
# the surrounding export script; only the export itself is shown here.
import torch

try:
    import onnx

    print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
    f = opt.weights.replace('.pth', '.onnx').replace('.pt', '.onnx')  # output filename
    torch.onnx.export(model, img, f, verbose=False, opset_version=10, input_names=['images'],
                      output_names=['classes', 'boxes'] if y is None else ['output'])
    # Checks
    onnx_model = onnx.load(f)  # load the ONNX model
    onnx.checker.check_model(onnx_model)  # check the ONNX model
    # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human-readable graph
    print('ONNX export success, saved as %s' % f)
except Exception as e:
    print('ONNX export failure: %s' % e)
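Step 2 below loads model_vgg6_sim.onnx; the _sim suffix suggests the exported model was additionally run through onnx-simplifier, a step the original snippets do not show. A minimal sketch of that assumed step (filenames are illustrative):

# Hypothetical simplification step; pip install onnx-simplifier
import onnx
from onnxsim import simplify

onnx_model = onnx.load('model_vgg6.onnx')  # assumed export filename
model_simp, check = simplify(onnx_model)   # fold constants, drop redundant ops
assert check, 'simplified ONNX model failed validation'
onnx.save(model_simp, 'model_vgg6_sim.onnx')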
2. ONNX -> TF
import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("model_vgg6_sim.onnx")  # load the ONNX model
tf_rep = prepare(onnx_model)                   # prepare the TF representation
tf_rep.export_graph("model_vgg6_sim.tf")       # export as a TF SavedModel directory
Environment used here:
tensorflow-cpu==2.6.0
onnx-tf==2.9.0
python==3.8
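Before converting further, the SavedModel can be sanity-checked with a dummy input. A minimal sketch, assuming the graph input is named 'images' (as in step 1) and has the 1x3x112x112 shape used by the inference script below:

import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load('model_vgg6_sim.tf')
infer = loaded.signatures['serving_default']               # default serving signature
dummy = np.random.rand(1, 3, 112, 112).astype(np.float32)  # NCHW, as exported from ONNX
print(infer(images=tf.constant(dummy)))                    # keyword 'images' is an assumption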
3. TF -> TFLite
import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model('model_vgg6_sim.tf')  # path to the SavedModel directory
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops
    tf.lite.OpsSet.SELECT_TF_OPS,    # enable TensorFlow ops (Flex fallback)
]
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
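If model size or CPU latency matters, post-training dynamic-range quantization can be enabled on the same converter before calling convert(). This is optional and not part of the original pipeline (the commented-out quantize_frozen_graph.tflite path in step 4 hints it was tried); accuracy should be re-checked afterwards:

# Optional: dynamic-range quantization (weights stored as int8, activations stay float)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
with open('model_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)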
4. TFLite inference
# -*- coding: utf-8 -*-
# @Time : 2021/10/20 9:15
# @Author : jw hao
# @File : infrence_tflite.py
# @Software: PyCharm
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # uncomment to force CPU-only execution
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import transforms  # used only for preprocessing
test_image_dir = 'data/test/'
# model_path = "./model/quantize_frozen_graph.tflite"
model_path = "models/model.tflite"

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensor details.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)
data_transforms = transforms.Compose([
    transforms.Resize(112),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
file_list = os.listdir(test_image_dir)
model_interpreter_time = 0
start_time = time.time()

# Iterate over the test images
for file in file_list:
    full_path = os.path.join(test_image_dir, file)
    image = Image.open(full_path).convert('RGB')  # ensure 3 channels
    image = image.resize((112, 112))
    image_np_expanded = data_transforms(image).unsqueeze(0)

    # Fill in the input tensor; set_tensor expects a numpy array, not a torch tensor
    model_interpreter_start_time = time.time()
    interpreter.set_tensor(input_details[0]['index'], image_np_expanded.numpy())

    # Run the model
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    model_interpreter_time += time.time() - model_interpreter_start_time

    # Squeeze the useless batch dimension out of the result
    result = np.squeeze(output_data)
    print('result:{}'.format(result))

used_time = time.time() - start_time
print('used_time:{}'.format(used_time))
print('model_interpreter_time:{}'.format(model_interpreter_time))
Known issues:
1. The TFLite model produced by this pipeline takes NCHW input rather than the NHWC layout TFLite kernels are written for, so the transpose ops inserted during conversion may hurt inference time when deploying on CPU.
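Because of this, any NHWC-shaped input (the natural layout for most image pipelines) has to be permuted before being fed to the interpreter. A minimal numpy sketch, assuming the 1x112x112x3 float input used above:

import numpy as np

img_nhwc = np.random.rand(1, 112, 112, 3).astype(np.float32)  # NHWC batch from an image pipeline
img_nchw = np.transpose(img_nhwc, (0, 3, 1, 2))               # reorder to the NCHW layout this model expects
interpreter.set_tensor(input_details[0]['index'], img_nchw)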