概述
hook函数可以分为两部分:关于tensor(第一种)和关于module(第二三四种)
tensor.register_hook
在反向传播完成时,非叶子结点的梯度会消失
tensor.register_hook作用:
1)完成保存非叶子结点的梯度
2)修改叶子结点的值
例如:保存a的梯度值;修改w的梯度值
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
a_grad.append(grad)
# 通过a.register_hook保存a_grad(中间结点的梯度)值
handle = a.register_hook(grad_hook)
y.backward()
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
grad *= 2
return grad*3
# 通过a.register_hook修改w的梯度值
handle = w.register_hook(grad_hook)
y.backward()
Module.register_forward_hook
获取forward种的feature map(是在某层conv执行之后再执行register_forward_hook)
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
# 注册hook
net.conv1.register_forward_hook(forward_hook)
Module.register_forward_pre_hook
def forward_pre_hook(module, data_input):
print("forward_pre_hook input:{}".format(data_input))
# 注册hook
net.conv1.register_forward_pre_hook(forward_pre_hook)
Module.register_backward_hook
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
# 注册hook
net.conv1.register_backward_hook(backward_hook)
整个过程下,register_forward_hook、register_forward_pre_hook、register_backward_hook的运行过程
可见,在模型实例化后再初始化再注册hook
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
def forward_pre_hook(module, data_input):
print("forward_pre_hook input:{}".format(data_input))
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)
# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
小结
HOOK函数的2.3.4种(即关于module部分)在module中的call函数中执行的。
call函数种的顺序是:
forward_pre_hook
forward
forward_hook
backward_hook
由此可见,module中的call函数并不是只有forward函数,而是借助hook函数实现其他的功能
最后
以上就是寒冷哑铃为你收集整理的课程笔记:HOOK函数的全部内容,希望文章能够帮你解决课程笔记:HOOK函数所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复