我是靠谱客的博主 任性画板,最近开发中收集的这篇文章主要介绍TensorFlow2.x冻结模型,保存为.pb格式方便部署,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

1、用tf.keras创建模型

inputs = tf.keras.Input(shape=(1000, 9,), name='input')
input = tf.keras.layers.Flatten(name="flatten")(inputs)
f1 = tf.keras.layers.Dense(512, activation='relu', name='dense_1')(input)
d1 = tf.keras.layers.Dropout(0.2,name='dropout_1')(f1)
f2 = tf.keras.layers.Dense(128, activation='relu', name='dense_2')(d1)
outputs = tf.keras.layers.Dense(3, activation='softmax', name='output')(f2)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='intrusion')
model.summary()

2、模型训练

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['sparse_categorical_accuracy'])
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=7, validation_split=0.2, validation_freq=1)

3、保存模型为.h5

h5_save_path = 'model.h5'
model.save(h5_save_path)

4、定义模型转化函数

def h5_to_pb(h5_save_path):
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models2",
                      name="model.pb",
                      as_text=False)

5、调用函数冻结模型为pb格式

h5_to_pb(h5_save_path)

模型输入输出名字在调用函数时可查看
在这里插入图片描述

这样就可以在其他平台部署模型了,而且不受tensroflow版本限制

最后

以上就是任性画板为你收集整理的TensorFlow2.x冻结模型,保存为.pb格式方便部署的全部内容,希望文章能够帮你解决TensorFlow2.x冻结模型,保存为.pb格式方便部署所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(54)

评论列表共有 0 条评论

立即
投稿
返回
顶部