概述
引言:
当大家在使用tf.assign()这个函数时,如果不是很了解这个函数的用法,很容易出错,而且似乎对应不同的tf版本其操作结果也会有细微的差别,本文是基于1.9.0版本的tf
进行描述的,对于更新的版本而言应该结论是一样的,但对于比较旧的版本,可能就会有细微差别。
首先我们看一下源码中的返回值说明:
update = tf.assign(ref, new_value) # 平时的使用写法
--------------------------------------------------------------------
Returns:
A `Tensor` that will hold the new value of 'ref' after
the assignment has completed.
也就是说,只有当这个赋值被完成时,该旧值ref
才会被修改成new_value
。不过这样描述还是太抽象了,那到底什么叫赋值被完成呢?下面我给大家放两个简单的例子,帮助大家理解
import tensorflow as tf
ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(ref_sum))
------------------------------------------------------
输出结果:3
然后你就会感到奇怪,这里与往常的直觉不一样,理论上ref_a
应该已经被修改为10了?带着疑问,我们看第二个例子
import tensorflow as tf
ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(update) # 唯一修改的地方
print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12
看到这里,是不是大家就明白了。所谓的赋值被完成其实指得是需要对tf.assign()
函数的返回值执行一下sess.run()
操作后,才能保证正常更新。
在明白了这个易错的地方后,我再介绍两种方法,来达到同样的目的。
方法一:采用ref_a = tf.assign(ref_a, 10)
操作,我们看一下代码和运行结果
import tensorflow as tf
ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
ref_a = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12
事实上,tf.assign(ref, new_value)
函数返回的结果就是参数中的new_value
,因此我们只需要用ref
来接收返回值也可以达到直接更新的效果
方法二:使用tf.control_dependencies()
函数,我们也同样来看一下代码和结果
import tensorflow as tf
ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
with tf.control_dependencies([update]):
ref_sum = tf.add(ref_a, ref_b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12
可以发现,结果也为我们预期想要达到的效果,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行。简单地说,就是实际在运行时,会先执行该函数传递的参数update
,再执行其辖域中的操作ref_sum = tf.add(ref_a, ref_b)
如果觉得我有地方讲的不好的或者有错误的欢迎给我留言,谢谢大家阅读(点个赞我可是会很开心的哦)~
最后
以上就是帅气夏天为你收集整理的Tensorflow:tf.assign()函数的使用方法及易错点的全部内容,希望文章能够帮你解决Tensorflow:tf.assign()函数的使用方法及易错点所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复