我是靠谱客的博主 内向小猫咪,最近开发中收集的这篇文章主要介绍tensorflow实现(indices)上采样,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

tensorflow实现segnet 和 learn deconv网络中的上采样。

def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    upsampling according argmax
    :param bottom: the output feature maps needed to be upsampled
    :param argmax: the indice made by tf.nn.max_pool_with_argmax()
    :param output_shape: 
    :param name:
    :return:
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        #  calculation new shape
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        flat_input_size = np.prod(input_shape)
        flat_output_size = np.prod(output_shape)
        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(argmax, [flat_input_size, 1])

        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])

        ret = tf.reshape(ret, output_shape)
        return ret

测试

import numpy as np
import tensorflow as tf
input_data = tf.constant(np.random.rand(16, 4, 4, 3), dtype=np.float32)

x, arg = tf.nn.max_pool_with_argmax(input_data, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    upsampling according argmax
    :param bottom: the output feature maps needed to be upsampled
    :param argmax: the indice made by tf.nn.max_pool_with_argmax()
    :param output_shape:
    :param name:
    :return:
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        #  calculation new shape
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        flat_input_size = np.prod(input_shape)
        flat_output_size = np.prod(output_shape)
        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(argmax, [flat_input_size, 1])

        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])

        ret = tf.reshape(ret, output_shape)
        return ret


ret = unpool_with_argmax(x, arg)

x_2 = tf.nn.max_pool(ret, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

with tf.Session() as sess:
    x_val, arg_val, ret_val, x_2_val = sess.run([x, arg, ret, x_2])
    print x_val[0, :, :, 0]
    print "#######################################"
    print ret_val[0, :, :, 0]
    print "**************************************"
    print arg_val[0, :, :, 0]
    print x_val.shape, arg_val.shape, ret_val.shape


输出结果

[[ 0.92141378  0.83250898]
 [ 0.96589577  0.92536974]]
#######################################
[[ 0.          0.          0.83250898  0.        ]
 [ 0.92141378  0.          0.          0.        ]
 [ 0.          0.96589577  0.          0.        ]
 [ 0.          0.          0.92536974  0.        ]]
**************************************
[[12  6]
 [27 42]]
(16, 2, 2, 3) (16, 2, 2, 3) (16, 4, 4, 3)

最后

以上就是内向小猫咪为你收集整理的tensorflow实现(indices)上采样的全部内容,希望文章能够帮你解决tensorflow实现(indices)上采样所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(47)

评论列表共有 0 条评论

立即
投稿
返回
顶部