我是靠谱客的博主 寒冷哑铃,最近开发中收集的这篇文章主要介绍课程笔记:HOOK函数,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

hook函数可以分为两部分:关于tensor(第一种)和关于module(第二三四种)
在这里插入图片描述

tensor.register_hook

在反向传播完成时,非叶子结点的梯度会消失
tensor.register_hook作用:
1)完成保存非叶子结点的梯度
2)修改叶子结点的值
在这里插入图片描述
例如:保存a的梯度值;修改w的梯度值
在这里插入图片描述

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    a_grad.append(grad)
# 通过a.register_hook保存a_grad(中间结点的梯度)值
handle = a.register_hook(grad_hook)

y.backward()
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3
# 通过a.register_hook修改w的梯度值
handle = w.register_hook(grad_hook)

y.backward()

Module.register_forward_hook

获取forward种的feature map(是在某层conv执行之后再执行register_forward_hook)
在这里插入图片描述

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)
# 注册hook
net.conv1.register_forward_hook(forward_hook)

Module.register_forward_pre_hook

在这里插入图片描述

def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))
# 注册hook
net.conv1.register_forward_pre_hook(forward_pre_hook)

Module.register_backward_hook

在这里插入图片描述

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))
# 注册hook
net.conv1.register_backward_hook(backward_hook)

整个过程下,register_forward_hook、register_forward_pre_hook、register_backward_hook的运行过程
可见,在模型实例化后再初始化再注册hook

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

小结

HOOK函数的2.3.4种(即关于module部分)在module中的call函数中执行的。
call函数种的顺序是:
forward_pre_hook
forward
forward_hook
backward_hook
在这里插入图片描述
由此可见,module中的call函数并不是只有forward函数,而是借助hook函数实现其他的功能

最后

以上就是寒冷哑铃为你收集整理的课程笔记:HOOK函数的全部内容,希望文章能够帮你解决课程笔记:HOOK函数所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(52)

评论列表共有 0 条评论

立即
投稿
返回
顶部