我是靠谱客的博主 傲娇草莓,最近开发中收集的这篇文章主要介绍tensorflow构建的ckpt文件转pb转onnx文件,深度学习模型推理时加速,以bert模型为例,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

在实际应用场景中,我们希望将训练好的模型转换为onnx格式,更便于后续的部署及推理时加速。

tensorflow模型需要将训练得到的ckpt文件转换为pb格式再转换为onnx格式。

这已经是本人一年多前的尝试了,可能现在这样做并不是最好的方式了。但是,无论学术还是工作中还是留有tensorflow训练的模型,希望转onnx加速一下拿来试试或者应付下手头的项目。因此,写了这篇博客。

pytorch训练得到的pt文件转onnx可参照下文:

pytorch构建的深度学习模型(pt文件)转换为onnx格式,并支持batch输入,以bert模型为例_记录与分享AI资料与学习过程-CSDN博客

1、环境配置

conda install cudatoolkit=10.1.243 cudnn=7.6.5
conda install tensorflow-gpu==1.14
pip install onnxruntime-gpu==1.4.0
pip install tqdm
pip install tf2onnx==1.6.2

2、在构建模型时对输入、输出指定相应的名字

需要对模型的输入输出取好名字,因为下一步在ckpt转pb文件

3、训练得到的ckpt文件转换成为pb文件

将训练得到的ckpt文件转换为pb文件,ckpt除了模型参数以外还含有模型训练相关的信息 如反向传播的梯度等,转换后的文件大小大致为之前的一半。

def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "output_length,pred_logits,transitions"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=False)
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants(
# 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
# for op in sess.graph.get_operations():
#
print(op.name, op.values())

 以上代码以序列标注任务为例,输出为output_node_names包含三部分,分别是句子长度、输出每个token在类别上的概率分布及概率。

下面测试一下转换后的pb模型,通过以下方法推理验证ckpt转换pb文件成功并正常预测

def freeze_graph_test(pb_path,input_ids,input_masks):
'''
:param pb_path:pb文件的路径
:param input_ids:bert token id
:param input_masks:0
:return:
'''
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 定义输入的张量名称,对应网络结构的输入张量
input_x = sess.graph.get_tensor_by_name("input_x_word:0")
input_mask = sess.graph.get_tensor_by_name("input_mask:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
# 定义输出的张量名称,对应网络结构的输出张量
logits = sess.graph.get_tensor_by_name('project/pred_logits:0')
lengths = sess.graph.get_tensor_by_name('output_length:0')
trans = sess.graph.get_tensor_by_name('transitions:0')
for i in range(0,2000):
logits1,lengths1,trans1 =sess.run([logits,lengths,trans], feed_dict=
{input_x: [input_ids[i]],
input_mask:[input_masks[i]],
is_training:False})

 主要就是指定好模型创建时候对应的输入和输出,然后用sess.run()指定好输入输出就能跑模型得到结果

4、通过tf2onnx工具将pb文件转为onnx文件

 由于版本的持续更新,可以安装更新的tf2onnx,前提查看版本算子的对应。tf2onnx工具详情可见一下链接(最新版本支持功能更加简便完善可以查看文档阅读):

GitHub - onnx/tensorflow-onnx: Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX

需要查询文档,看清不同版本不同opset下支持的算子:

tensorflow-onnx/support_status.md at master · onnx/tensorflow-onnx · GitHub

通过官方文档中所述方法转换将pb模型转换为onnx格式(本人采用tf2onnx 1.6.2,新版本指令略有变动可以自行查阅官方文档):

python -m tf2onnx.convert --input ./frozen_model.pb --inputs input_x_word:0,input_mask:0,is_training:0 --outputs project/pred_logits:0,output_length:0,transitions:0 --output ./bert6.onnx --verbose --opset=11

转换成功后,通过onnxruntime推理并验证模型正常预测:

import onnxruntime as ort
import numpy as np
import time
session = ort.InferenceSession("./bert.onnx")
is_training = session.get_inputs()[0].name
input_x_word = session.get_inputs()[1].name
input_mask = session.get_inputs()[2].name
logits = session.get_outputs()[0].name
lengths = session.get_outputs()[1].name
trans = session.get_outputs()[2].name
for i in range(0,2000):
logits1,lengths1,trans1 =session.run([logits,lengths,trans], {input_x_word: [input_ids[i]], input_mask:[input_masks[i]], is_training:[]})

最后

以上就是傲娇草莓为你收集整理的tensorflow构建的ckpt文件转pb转onnx文件,深度学习模型推理时加速,以bert模型为例的全部内容,希望文章能够帮你解决tensorflow构建的ckpt文件转pb转onnx文件,深度学习模型推理时加速,以bert模型为例所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(61)

评论列表共有 0 条评论

立即
投稿
返回
顶部