概述
Motivation
错误:
tensorflow不能直接给Variable赋值,比如:
embedding_var = tf.Variable(1)
test_var = 10
embedding_var = test_var
或者:
embedding_var = tf.Variable(1)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
x.assign(1)
解决方法
正确:
如果只需要给Variable赋值一次,可以通过assign这样进行赋值:
import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(x)
print sess.run(y)
print sess.run(x)
但是通常赋一次值的意义不大,因为有时我们想将网络中的一些输出通过saver()保存下来,或者通过tensorboard查看网络的embedding投影,那么就需要将网络中产生的输出以变量的形式储存,这样就可以在saver.save()的时候将这些输出给保存到本地,又因为tensorflow不能在图外面直接对变量进行操作,所以我通过用一个占位符来传输网络的输出结果,然后再session里面取出网络的输出值,feed给该占位符,然后将占位符的值赋给一个临时变量作为保存,如下,亲测有效:
flat_value = np.zeros([200,4*4*32])
mid_vari = tf.placeholder(tf.float32, [200,4*4*32],name="mid_vari")
embedding_var = tf.Variable(tf.zeros([200,4*4*32]), name=NAME_TO_VISUALISE_VARIABLE)
mid_vari_2 = tf.assign(embedding_var,mid_vari)
with tf.Session() as sess:
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
for i in range(200):
flat_value,_=sess.run([flat,mid_vari_2],feed_dict={x:one_x,y:labels,mid_vari:flat_value})
比较周折,不过也是试了很多办法才找到的解决方案T_T。
参考
https://blog.csdn.net/mustar_2017/article/details/79336679
最后
以上就是奋斗水蜜桃为你收集整理的Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】的全部内容,希望文章能够帮你解决Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复