概述
tf.argmax(data, axis=None)
用tensorflow 做 mnist分类时,用到这个接口,于是就研究了下这个接口的用法:
如果是一维数组呢?
data = tf.constant([1,2,3])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
#轴默认为0
print(sess.run(tf.argmax(data)))
>>> 2
>>> 2
这个很好理解,因为data是一维数组,axis只能为0(如果是1就会报错),结果返回数组中最大值的下标,所以是2
如果是二维数组呢?
data = tf.constant([[1,2,3]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data, 1)))
>>> [0,0,0]
>>> [2]
是不是有点晕了?
我是这么理解的:
Axis = 0时:
只有data[0] = [1,2,3], 按照对应位置比较,因为只有data[0],1的对应位置为空,所以1是最大值,2的对应位置为空所以2是最大值,3的对应位置为空所以3是最大值。
而argmax函数返回的是最大值的索引,因为1, 2,3 都属于data[0],所以返回值是 [0, 0, 0].
Axis = 1时:
Data[0][0] = 1
Data[0][1] = 2
Data[0][2] = 3
1和2 和 3比较显然是 3最大,3的索引为2,所以返回[2]
再看一个二维数组,可能就明白了:
data = tf.constant([[1,2,3], [4,5,6]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data, 1)))
>>> [1, 1, 1]
>>> [2, 2]
Axis = 0时:
Data[0] = [1,2,3]
Data[1] = [4,5,6]
对应位置比较:4 > 1, 5>2, 6>3, 所以返回 4,5,6所在的索引位置[1,1,1]
Axis = 1时:
Data[0][0] = 1
Data[0][1] = 2
Data[0][2] = 3
对应位置比较 3最大,3的索引为2
Data[1][0] = 4
Data[1][1] = 5
Data[1][2] = 6
对应位置比较6最大,6的索引为2
所以最后返回[2,2].
同样如果是三维数组:
data = tf.constant([[[1,2,3]],
[[7, 1,9]]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0))) # [[1 0 1]]
print(sess.run(tf.argmax(data, 1))) # [[0,0,0],[0,0,0]]
print(sess.run(tf.argmax(data, 2))) # [[2],[2]]
同样步骤分析:
Axis = 0时:
Data[0] = [[1, 2, 3]]
Data[1] = [[7, 1, 9]]
对应位置比较 7>1, 2 >1, 9> 3, 7属于索引1,2属于索引0,9属于索引1,所以返回[[1, 0,1]].
Axis = 1时:
Data[0][0] = [1,2,3]
1 2 3,对应位置分别为空,所以1,2,3在对应位置都是最大,1,2,3,都属于索引为0,返回[0,0,0]。
Data[1][0] = [7, 1,9]
7 1 9,对应位置分别为空,所以7,1,9在对应位置都是最大,7,1,9,都属于索引为0,返回[0,0,0]
所以最后返回[[0,0,0],[0,0,0]]。
Axis = 2时:
Data[0][0][0] = 1
Data[0][0][1] = 2
Data[0][0][2] = 3
3比较最大,3所在的索引为2,返回 2,
Data[1][0][0] = 7
Data[1][0][1] = 1
Data[1][0][2] = 9
9 最大,9所在的索引为2,返回2
所以最后返回[[2],[2]].
如果是四维或者更高维度,都是按照同样的方法。
最后
以上就是平常电灯胆为你收集整理的tf.argmax()的详细用法的全部内容,希望文章能够帮你解决tf.argmax()的详细用法所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复