概述
tf.get_default_graph().get_tensor_by_name("<name>:0")
1、基础知识
# 创建图
a = tf.constant([[1.0, 2.0], [3.0, 4.0]], name="a")
b = tf.constant([[1.0, 1.0], [0.0, 1.0]], name="b")
c = tf.matmul(a, b, name='example')
with tf.Session() as sess:
print(c.name)
# example:0
# <name>:0 (0 refers to endpoint which is somewhat redundant)
# 形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
tensor = tf.get_default_graph().get_tensor_by_name("example:0")
print(tensor)
# Tensor("example:0", shape=(2, 2), dtype=float32)
all_tensor = tf.get_default_graph().as_graph_def().node
print(all_tensor)
2、应用:
def load_model(model):
# Check if the model is a model directory (containing a metagraph and a checkpoint file)
# or if it is a protobuf file with a frozen graph
model_exp = os.path.expanduser(model)
if (os.path.isfile(model_exp)):
print('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
ckpt = tf.train.get_checkpoint_state(model_exp)
if ckpt and ckpt.model_checkpoint_path:
# 也可以使用 tf.train.import_meta_graph()
saver.restore(sess, ckpt.model_checkpoint_path)
with tf.Graph().as_default():
with tf.Session() as sess:
# Load the model
load_model(model_file)
# Get input and output tensors
images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
# Run forward pass to calculate embeddings
feed_dict = { images_placeholder: images, phase_train_placeholder:False }
emb = sess.run(embeddings, feed_dict=feed_dict)
最后
以上就是精明汽车为你收集整理的【TensorFlow学习笔记】get_tensor_by_name的全部内容,希望文章能够帮你解决【TensorFlow学习笔记】get_tensor_by_name所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复