我是靠谱客的博主 朴素金针菇,最近开发中收集的这篇文章主要介绍TF2 build-in Keras在eager及非eager模式下callback训练过程中梯度的方式,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Class Activation Map / Gradient Attention Map

分类/分割任务中可能会需要对训练过程中某些层的计算梯度进行操作,对于Keras来说我们可以通过使用Callback()实现返回梯度的目的,具体的例子如下所示,分为非eager模式和eager模式两部分。

1. 非eager模式

tf.compat.v1.disable_eager_execution()  # 这句一定要加上!

def get_gradient_func(model):
    ## If using 'tf.compat.v1.disable_eager_execution()'
    # 博主使用的模型有两个输出所以此处定义了两个gradients及function
    grads_1 = K.gradients(model.outputs[0], model.inputs[0])
    grads_2 = K.gradients(model.outputs[1], model.inputs[0])
    inputs = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    func_1 = K.function(inputs, grads_1)
    func_2 = K.function(inputs, grads_2)
    return func_lipid, func_calcium

class CustomCallback(Callback):
    def __init__(self,model,training_generator,save_grad_path):
        self.model = model
        self.training_generator = training_generator
        self.save_grad_path = save_grad_path

    def on_epoch_end(self, epoch, logs=None):
        if (epoch+1)%10==0:
            epoch_gradient_1 = []
            epoch_gradient_2 = []
            get_gradient_1, get_gradient_2 = get_gradient_func(self.model)
            for step,batch in enumerate(self.training_generator):
                batch = tuple(t for t in batch)
                train_img = batch[0]
                train_label = batch[1]
                grads_1 = get_gradient_1([train_img, train_label, np.ones(16)]) # batchSize=16
                grads_2 = get_gradient_2([train_img, train_label, np.ones(16)])
                # 存储每个epoch output对input的梯度,下一个epoch时epoch_gradient变量会清空
                epoch_gradient_1.append(grads_1[0][:,:,:,3])
                epoch_gradient_2.append(grads_1[0][:,:,:,3])
           
        else:
            pass

2. eager模式

class CustomCallback(Callback):
    def __init__(self,model,training_generator,save_grad_path):
        self.model = model
        self.training_generator = training_generator
        self.save_grad_path = save_grad_path

    def on_epoch_end(self, epoch, logs=None):
        if (epoch+1)%10==0:
            epoch_gradient_1 = []
            epoch_gradient_2 = []
            input_layer = self.model.get_layer("data") # 模型的输入层'data',也可以是其他名字,根据model各层的起名来定
            # 由于是计算output对input的梯度,所以定义一个临时的模型用来进行out,data这两个tensor的输出
            # 若想计算Output关于其他层的梯度,只需要将input_layer.output替换为其他层的output即可
            temp_model = Model([self.model.inputs],[self.model.output,input_layer.output])
            for step,batch in enumerate(self.training_generator):
                batch = tuple(t for t in batch)
                train_img = batch[0]
                train_label = batch[1]
                # 默认的non-persisitent模式下,with tf.GradientTape() as gtape:一次只能使用gtape.gradient一次,连续使用会报错
                with tf.GradientTape() as gtape:
                    out, data = temp_model(train_img)
                    # 由于gtape只能跟踪trainable variants,而model的input是一个non-trainable的变量,所以要使用gtape.watch()进行追踪
                    gtape.watch(data) 
                grads_1 = gtape.gradient(out[0], data)
                with tf.GradientTape() as gtape:
                    out, data = temp_model(train_img)
                    gtape.watch(data)
                grads_2 = gtape.gradient(out[1], data)
                epoch_gradient_1.append(grads_1[:,:,:,3])
                epoch_gradient_2.append(grads_2[:,:,:,3])

        else:
            pass

个人推荐使用eager模式。

References:

https://stackoverflow.com/questions/58322147/how-to-generate-cnn-heatmaps-using-built-in-keras-in-tf2-0-tf-keras

https://stackoverflow.com/questions/61568665/tf2-compute-gradients-in-keras-callback-in-non-eager-mode
https://discuss.pytorch.org/t/generating-the-class-activation-maps/42887
https://www.tensorflow.org/api_docs/python/tf/GradientTape#gradient

最后

以上就是朴素金针菇为你收集整理的TF2 build-in Keras在eager及非eager模式下callback训练过程中梯度的方式的全部内容,希望文章能够帮你解决TF2 build-in Keras在eager及非eager模式下callback训练过程中梯度的方式所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部