我是靠谱客的博主 奋斗水蜜桃,最近开发中收集的这篇文章主要介绍Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Motivation

错误:
tensorflow不能直接给Variable赋值,比如:

embedding_var = tf.Variable(1)
test_var = 10
embedding_var = test_var

或者:

embedding_var = tf.Variable(1)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
x.assign(1)

解决方法

正确:
如果只需要给Variable赋值一次,可以通过assign这样进行赋值:

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x)
    print sess.run(y)
    print sess.run(x)

但是通常赋一次值的意义不大,因为有时我们想将网络中的一些输出通过saver()保存下来,或者通过tensorboard查看网络的embedding投影,那么就需要将网络中产生的输出以变量的形式储存,这样就可以在saver.save()的时候将这些输出给保存到本地,又因为tensorflow不能在图外面直接对变量进行操作,所以我通过用一个占位符来传输网络的输出结果,然后再session里面取出网络的输出值,feed给该占位符,然后将占位符的值赋给一个临时变量作为保存,如下,亲测有效:

flat_value = np.zeros([200,4*4*32]) 
mid_vari = tf.placeholder(tf.float32, [200,4*4*32],name="mid_vari")
embedding_var = tf.Variable(tf.zeros([200,4*4*32]), name=NAME_TO_VISUALISE_VARIABLE)
mid_vari_2 = tf.assign(embedding_var,mid_vari)

with tf.Session() as sess:
    saver =  tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        flat_value,_=sess.run([flat,mid_vari_2],feed_dict={x:one_x,y:labels,mid_vari:flat_value})

比较周折,不过也是试了很多办法才找到的解决方案T_T。

参考

https://blog.csdn.net/mustar_2017/article/details/79336679

最后

以上就是奋斗水蜜桃为你收集整理的Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】的全部内容,希望文章能够帮你解决Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(63)

评论列表共有 0 条评论

立即
投稿
返回
顶部