用 TensorFlow 实现 SegNet 和 Learning Deconvolution Network（DeconvNet）中的上采样（max unpooling）。
def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    """Upsample ``bottom`` by scattering each value back to its ``argmax`` position.

    This is the "max unpooling" step of SegNet / DeconvNet. Each pooled value
    is written to the location recorded in ``argmax``; every other location of
    the output is zero.

    :param bottom: pooled feature maps to upsample, shape (N, h, w, C), e.g. the
        first output of ``tf.nn.max_pool_with_argmax``.
    :param argmax: indices produced by ``tf.nn.max_pool_with_argmax()`` for the
        same pooling op; must have the same shape as ``bottom``.
    :param output_shape: static shape (N, H, W, C) of the result. It must equal
        the shape of the tensor that was pooled. Defaults to ``bottom``'s shape
        with h and w doubled, which is only correct for a 2x2 / stride-2 pool
        over an input with even H and W.
    :param name: name scope for the created ops.
    :return: tensor of shape ``output_shape`` with ``bottom``'s dtype.
    """
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        # Compute the upsampled shape when the caller did not supply one.
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        output_shape = [int(d) for d in output_shape]
        flat_input_size = int(np.prod(input_shape))
        flat_output_size = int(np.prod(output_shape))
        out_per_example = int(np.prod(output_shape[1:]))

        # max_pool_with_argmax may or may not fold the batch index into argmax
        # (see its include_batch_in_index argument). Taking the index modulo
        # the per-example size and adding an explicit batch offset yields the
        # correct global position in both cases. Without the offset, every
        # example is scattered into example 0 and scatter_nd sums the overlaps.
        argmax = tf.cast(argmax, tf.int64)
        batch_offset = tf.reshape(
            tf.range(output_shape[0], dtype=tf.int64) * out_per_example,
            [output_shape[0], 1, 1, 1])
        flat_index = argmax % out_per_example + batch_offset

        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(flat_index, [flat_input_size, 1])
        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])
        ret = tf.reshape(ret, output_shape)
        return ret
import numpy as np
import tensorflow as tf
# A random batch of 16 feature maps, 4x4 spatial with 3 channels.
input_data = tf.constant(np.random.rand(16, 4, 4, 3), dtype=np.float32)
# 2x2 max pooling with stride 2; `arg` records where each maximum was taken from.
x, arg = tf.nn.max_pool_with_argmax(
    input_data,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding='SAME',
)
def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    """Upsample ``bottom`` by scattering each value back to its ``argmax`` position.

    This is the "max unpooling" step of SegNet / DeconvNet. Each pooled value
    is written to the location recorded in ``argmax``; every other location of
    the output is zero.

    :param bottom: pooled feature maps to upsample, shape (N, h, w, C), e.g. the
        first output of ``tf.nn.max_pool_with_argmax``.
    :param argmax: indices produced by ``tf.nn.max_pool_with_argmax()`` for the
        same pooling op; must have the same shape as ``bottom``.
    :param output_shape: static shape (N, H, W, C) of the result. It must equal
        the shape of the tensor that was pooled. Defaults to ``bottom``'s shape
        with h and w doubled, which is only correct for a 2x2 / stride-2 pool
        over an input with even H and W.
    :param name: name scope for the created ops.
    :return: tensor of shape ``output_shape`` with ``bottom``'s dtype.
    """
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        # Compute the upsampled shape when the caller did not supply one.
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        output_shape = [int(d) for d in output_shape]
        flat_input_size = int(np.prod(input_shape))
        flat_output_size = int(np.prod(output_shape))
        out_per_example = int(np.prod(output_shape[1:]))

        # max_pool_with_argmax may or may not fold the batch index into argmax
        # (see its include_batch_in_index argument). Taking the index modulo
        # the per-example size and adding an explicit batch offset yields the
        # correct global position in both cases. Without the offset, every
        # example is scattered into example 0 and scatter_nd sums the overlaps.
        argmax = tf.cast(argmax, tf.int64)
        batch_offset = tf.reshape(
            tf.range(output_shape[0], dtype=tf.int64) * out_per_example,
            [output_shape[0], 1, 1, 1])
        flat_index = argmax % out_per_example + batch_offset

        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(flat_index, [flat_input_size, 1])
        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])
        ret = tf.reshape(ret, output_shape)
        return ret
# Unpool the pooled maps, then pool them again. The unpooled map is zero
# everywhere except at the recorded argmax positions, and the input here is
# non-negative (np.random.rand). Re-pooling should therefore give back exactly
# the pooled values.
ret = unpool_with_argmax(x, arg)
x_2 = tf.nn.max_pool(ret, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
with tf.Session() as sess:
    x_val, arg_val, ret_val, x_2_val = sess.run([x, arg, ret, x_2])
# Inspect the first example, first channel. Single-argument print(...) calls
# give the same output under Python 2 and Python 3.
print(x_val[0, :, :, 0])
print("#######################################")
print(ret_val[0, :, :, 0])
print("**************************************")
print(arg_val[0, :, :, 0])
print('{} {} {}'.format(x_val.shape, arg_val.shape, ret_val.shape))
# Round-trip check over the whole batch: True when unpooling put every value back
# in its own example and position.
print(np.array_equal(x_val, x_2_val))
[[ 0.92141378 0.83250898]
[ 0.96589577 0.92536974]]
[[ 0. 0. 0.83250898 0. ]
[ 0.92141378 0. 0. 0. ]
[ 0. 0.96589577 0. 0. ]
[ 0. 0. 0.92536974 0. ]]
[[12 6]
[27 42]]
(16, 2, 2, 3) (16, 2, 2, 3) (16, 4, 4, 3)
发表评论 取消回复