概述
本文演示如何在 Keras 中设计共享参数层（三个输入分支复用同一个特征抽取子模型）、如何构建中间层模型以输出中间结果，以及如何保存训练好的模型。
1. 示例代码
#coding:utf-8
from tensorflow.keras import Input
from tensorflow.keras.layers import Dense, Lambda, Concatenate, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# from tensorflow.keras.layers.embeddings import Embedding
import tensorflow as tf
import numpy as np
import os
import shutil
# https://www.jianshu.com/p/4e45c7c4eb43
def flinkModel(*, img_dim=512, side_dim=10, num_outputs=4, output_activation=None):
    """Build a 3-branch Keras model whose branches share one feature extractor.

    A single sub-model (``featureExtractModel``) maps one sample's
    ``img`` / ``time`` / ``device`` / ``date`` vectors to a feature vector made
    of four concatenated 8-unit Dense outputs. That same sub-model is called on
    three separate groups of inputs, so all three branches share its weights.
    The three branch features are concatenated, passed through an 8-unit ReLU
    layer named ``merge_layers8_3``, and then through the output layer ``y``.

    Keyword Args:
        img_dim: width of each ``img`` input vector (default 512).
        side_dim: width of each ``time`` / ``device`` / ``date`` input
            vector (default 10).
        num_outputs: number of units in the final ``y`` layer (default 4).
        output_activation: activation of the ``y`` layer. The default
            ``None`` keeps it linear (raw logits), as in the original
            example; pass e.g. ``'softmax'`` to get class probabilities.

    Returns:
        An uncompiled ``Model`` taking 12 inputs in the order
        ``img1, time1, device1, date1, img2, ..., date3`` and producing the
        output of the ``y`` layer (``num_outputs`` units).
    """
    # Shared-weight feature extractor: built once, reused by every branch.
    img = Input(shape=(img_dim,), name='img')
    time = Input(shape=(side_dim,), name='time')
    device = Input(shape=(side_dim,), name='device')
    date = Input(shape=(side_dim,), name='date')

    # Image features: a ReLU stack narrowing 256 -> 128 -> 64 -> 8 units.
    img_feature = img
    for units in (256, 128, 64, 8):
        img_feature = Dense(units, activation='relu')(img_feature)

    # Time/space side features: one 8-unit ReLU layer each.
    time_feature = Dense(8, activation='relu')(time)
    device_feature = Dense(8, activation='relu')(device)
    date_feature = Dense(8, activation='relu')(date)

    feature = Concatenate(axis=1, name='feature')(
        [img_feature, time_feature, device_feature, date_feature])
    featureExtractModel = Model(inputs=[img, time, device, date], outputs=[feature])

    # Three input groups, each with its own Input tensors. All Inputs are
    # created before any branch is wired up, in the order
    # img, time, device, date per group (same order as the original code).
    branch_inputs = [
        [Input(shape=(img_dim,)), Input(shape=(side_dim,)),
         Input(shape=(side_dim,)), Input(shape=(side_dim,))]
        for _ in range(3)
    ]
    # Calling the same sub-model on every group is what shares the weights.
    branch_features = [featureExtractModel(group) for group in branch_inputs]

    # NOTE: the name 'merge_layers10_3' does not describe this layer's width;
    # it is kept unchanged so existing get_layer() lookups keep working.
    merge_layers = Concatenate(name='merge_layers10_3')(branch_features)
    merge_layers = Dense(8, activation='relu', name='merge_layers8_3')(merge_layers)
    out = Dense(num_outputs, activation=output_activation, name='y')(merge_layers)

    # Flatten to img1, time1, device1, date1, img2, ..., date3.
    model_inputs = [tensor for group in branch_inputs for tensor in group]
    class_models = Model(inputs=model_inputs, outputs=[out])
    return class_models
# Toy training data: three groups of (img, time, device, date) arrays with
# 1000 samples each, in the same order as the 12 inputs of flinkModel().
x1_1 = np.random.random((1000, 512))
x2_1 = np.random.random((1000, 10))
x3_1 = np.random.random((1000, 10))
x4_1 = np.random.random((1000, 10))
x1_2 = np.random.random((1000, 512))
x2_2 = np.random.random((1000, 10))
x3_2 = np.random.random((1000, 10))
x4_2 = np.random.random((1000, 10))
x1_3 = np.random.random((1000, 512))
x2_3 = np.random.random((1000, 10))
x3_3 = np.random.random((1000, 10))
x4_3 = np.random.random((1000, 10))
all_inputs = [x1_1, x2_1, x3_1, x4_1,
              x1_2, x2_2, x3_2, x4_2,
              x1_3, x2_3, x3_3, x4_3]
# One-hot targets over the model's 4 output units. (Random integers in 0..9,
# as originally used, are not a valid target distribution for categorical
# cross-entropy.)
y = np.eye(4)[np.random.randint(4, size=1000)]

model = flinkModel()
# Compile the model. flinkModel()'s final 'y' layer is linear by default,
# so the loss must interpret its outputs as logits.
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              optimizer='sgd', metrics=['accuracy'])
# Train the model and keep the per-epoch training details in `history`.
history = model.fit(all_inputs, y, batch_size=32, epochs=10, verbose=1,
                    callbacks=None, validation_split=0.1,
                    validation_data=None, shuffle=True,
                    class_weight=None, sample_weight=None,
                    initial_epoch=0)

# Print model information.
model.summary()
# https://blog.csdn.net/leviopku/article/details/86310758
print('layers:')
for layer in model.layers:
    print(layer.name)
print('inputs:')
for tensor in model.inputs:
    print(tensor.name)
print('outputs:')
for tensor in model.outputs:
    print(tensor.name)

# Rebuild a sub-model that ends at the intermediate 'merge_layers8_3' layer
# so its activations can be inspected.
print('mid_model:')
# https://github.com/keras-team/keras/issues/13743
mid_model = Model(model.input, model.get_layer('merge_layers8_3').output)
# Intermediate features of the first sample only.
mid_output = mid_model.predict(all_inputs)[0]
print('mid_ouput:', mid_output)

# Save the intermediate-layer model.
model_path = 'model'
# Remove any export left by a previous run; the SavedModel builder will not
# write into an existing, non-empty directory.
if os.path.exists(model_path):
    shutil.rmtree(model_path)
if hasattr(tf.keras.backend, 'get_session'):
    # TF 1.x graph mode: export mid_model's graph, keying its input/output
    # tensors by tensor name.
    tf.saved_model.simple_save(
        tf.keras.backend.get_session(),
        model_path,
        inputs={t.name: t for t in mid_model.inputs},
        outputs={t.name: t for t in mid_model.outputs},
    )
else:
    # TF 2.x (tf.keras 2) has neither tf.keras.backend.get_session nor
    # tf.saved_model.simple_save; export with tf.saved_model.save instead.
    tf.saved_model.save(mid_model, model_path)
print('Save trained model to {}'.format(model_path))
最后
以上就是腼腆小虾米为你收集整理的keras 共享参数层设计以及中间结果输出和模型结果保存的全部内容,希望文章能够帮你解决keras 共享参数层设计以及中间结果输出和模型结果保存所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复