概述
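使用 TensorFlow 实现 SegNet 和 Learning Deconvolution Network(DeconvNet)中的上采样(unpooling):利用 tf.nn.max_pool_with_argmax() 记录的最大值索引,把池化后的特征值放回原来的位置,其余位置补零。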
def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    Upsample ``bottom`` by writing every value back to the position stored in ``argmax``.

    All positions of the result that no index points to are zero.

    :param bottom: 4-D feature map to upsample (the output of the max pooling)
    :param argmax: the indices produced by tf.nn.max_pool_with_argmax(), same shape as
        ``bottom``; they may or may not already include the batch offset
    :param output_shape: 4 integers giving the shape of the result. Defaults to the static
        shape of ``bottom`` with dimensions 1 and 2 doubled, which requires that static
        shape to be fully defined. It must equal the shape of the tensor that was pooled,
        so pass it explicitly when that tensor had odd spatial dimensions.
    :param name: name scope under which the ops are created
    :return: tensor of shape ``output_shape``
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        # Derive the default output shape from the 2x2 pooling window.
        if output_shape is None:
            input_shape = bottom.get_shape().as_list()
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])

        batch_size = output_shape[0]
        image_size = int(np.prod(output_shape[1:]))
        flat_output_size = batch_size * image_size

        # Depending on the TensorFlow version and the include_batch_in_index flag, argmax is
        # flattened either per image or across the whole batch. Reduce it to a per-image
        # index, then add each image's own offset so both conventions scatter correctly
        # instead of piling every image into the slots of the first one.
        batch_offsets = tf.reshape(
            tf.range(batch_size, dtype=argmax.dtype) * image_size,
            [batch_size, 1, 1, 1])
        flat_indices = argmax % image_size + batch_offsets

        bottom_ = tf.reshape(bottom, [-1])
        argmax_ = tf.reshape(flat_indices, [-1, 1])
        # scatter_nd expects the shape argument to have the same integer type as the indices.
        flat_shape = tf.constant([flat_output_size], dtype=argmax.dtype)
        ret = tf.scatter_nd(argmax_, bottom_, flat_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
测试
import numpy as np
import tensorflow as tf
# Random float32 test batch of shape (16, 4, 4, 3).
input_data = tf.constant(np.random.rand(16, 4, 4, 3).astype(np.float32))
# 2x2 max pooling with stride 2: `x` is the pooled tensor, `arg` the index of each maximum.
x, arg = tf.nn.max_pool_with_argmax(
    input_data,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding='SAME')
def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    Upsample ``bottom`` by writing every value back to the position stored in ``argmax``.

    All positions of the result that no index points to are zero.

    :param bottom: 4-D feature map to upsample (the output of the max pooling)
    :param argmax: the indices produced by tf.nn.max_pool_with_argmax(), same shape as
        ``bottom``; they may or may not already include the batch offset
    :param output_shape: 4 integers giving the shape of the result. Defaults to the static
        shape of ``bottom`` with dimensions 1 and 2 doubled, which requires that static
        shape to be fully defined. It must equal the shape of the tensor that was pooled,
        so pass it explicitly when that tensor had odd spatial dimensions.
    :param name: name scope under which the ops are created
    :return: tensor of shape ``output_shape``
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        # Derive the default output shape from the 2x2 pooling window.
        if output_shape is None:
            input_shape = bottom.get_shape().as_list()
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])

        batch_size = output_shape[0]
        image_size = int(np.prod(output_shape[1:]))
        flat_output_size = batch_size * image_size

        # Depending on the TensorFlow version and the include_batch_in_index flag, argmax is
        # flattened either per image or across the whole batch. Reduce it to a per-image
        # index, then add each image's own offset so both conventions scatter correctly
        # instead of piling every image into the slots of the first one.
        batch_offsets = tf.reshape(
            tf.range(batch_size, dtype=argmax.dtype) * image_size,
            [batch_size, 1, 1, 1])
        flat_indices = argmax % image_size + batch_offsets

        bottom_ = tf.reshape(bottom, [-1])
        argmax_ = tf.reshape(flat_indices, [-1, 1])
        # scatter_nd expects the shape argument to have the same integer type as the indices.
        flat_shape = tf.constant([flat_output_size], dtype=argmax.dtype)
        ret = tf.scatter_nd(argmax_, bottom_, flat_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
# Unpool the pooled tensor back to the input resolution, then pool that result again.
ret = unpool_with_argmax(x, arg)
x_2 = tf.nn.max_pool(ret, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
with tf.Session() as sess:
    # x_2 is evaluated together with the other tensors; its value is not printed below.
    x_val, arg_val, ret_val, x_2_val = sess.run([x, arg, ret, x_2])
    # Show channel 0 of the first image: pooled values, unpooled map, then the indices.
    # Single-argument print(...) calls produce the same text on Python 2 and Python 3.
    print(x_val[0, :, :, 0])
    print("#######################################")
    print(ret_val[0, :, :, 0])
    print("**************************************")
    print(arg_val[0, :, :, 0])
    print("%s %s %s" % (x_val.shape, arg_val.shape, ret_val.shape))
输出结果
[[ 0.92141378 0.83250898]
[ 0.96589577 0.92536974]]
#######################################
[[ 0. 0. 0.83250898 0. ]
[ 0.92141378 0. 0. 0. ]
[ 0. 0.96589577 0. 0. ]
[ 0. 0. 0.92536974 0. ]]
**************************************
[[12 6]
[27 42]]
(16, 2, 2, 3) (16, 2, 2, 3) (16, 4, 4, 3)
最后
以上就是内向小猫咪为你收集整理的tensorflow实现(indices)上采样的全部内容,希望文章能够帮你解决tensorflow实现(indices)上采样所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复