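Adapted from https://zhuanlan.zhihu.com/p/279903361; original source: https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904
My earlier notes on the official Module documentation already covered hooks in part. This article is more concrete, though, and gave me a much better sense of what hooks are actually for.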
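1. What is a hook
A hook is a function that runs automatically after a specific event, much like a callback. For example, a forward hook registered with `register_forward_hook` is called every time the module's `forward()` has computed its output.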
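A minimal sketch of the idea follows. The toy `nn.Linear` layer, the `calls` list and the `log_shapes` function are illustrative names, not from the original article. It registers a forward hook, lets it fire on one forward pass, and then detaches it using the handle that `register_forward_hook` returns:

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
calls = []  # (input shape, output shape) pairs recorded by the hook

def log_shapes(module, inputs, output):
    # Forward hooks receive the module, a tuple of positional inputs, and the output
    calls.append((tuple(inputs[0].shape), tuple(output.shape)))

handle = layer.register_forward_hook(log_shapes)
layer(torch.randn(3, 4))   # hook fires and records ((3, 4), (3, 2))
handle.remove()            # detach the hook
layer(torch.randn(5, 4))   # hook no longer fires
print(calls)  # [((3, 4), (3, 2))]
```

Because the hook returns `None`, the layer's output is left unchanged. If a forward hook returns a value instead, PyTorch uses that value in place of the module's output.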
2. What hooks are used for
2.1 Showing execution details for debugging
First, create a wrapper class:
import torch
from torch import nn, Tensor

class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Register a forward hook on each direct child layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            # A forward hook receives (module, input, output) after forward() runs
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
With this wrapper, we can print the output shape of every top-level layer inside a model:
import torch
from torchvision.models import resnet50
verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
_ = verbose_resnet(dummy_input)
Output:
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])
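The wrapper above prints from inside the hooks and never removes them. As a sketch of an alternative, assuming you would rather collect the shapes and detach the hooks afterwards, the hypothetical helper `record_shapes` below stores the shapes in a dict and returns the handles. A small `nn.Sequential` stands in for resnet50, so only `torch` is required:

```python
import torch
from torch import nn

def record_shapes(model: nn.Module):
    """Hook every direct child; return the shape log and the hook handles (hypothetical helper)."""
    shapes = {}
    handles = []
    for name, child in model.named_children():
        # Bind `name` through a default argument so each hook keeps its own layer name
        def hook(module, inputs, output, name=name):
            shapes[name] = tuple(output.shape)
        handles.append(child.register_forward_hook(hook))
    return shapes, handles

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
shapes, handles = record_shapes(model)
model(torch.randn(2, 8))
for h in handles:
    h.remove()  # detach the hooks once the inspection is done
print(shapes)  # {'0': (2, 16), '1': (2, 16), '2': (2, 4)}
```

The default-argument binding serves the same purpose as storing `__name__` on each layer in the original: it gives every hook the correct layer name, rather than having them all see the last value of the loop variable.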
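2.2 Feature extraction
Hooks let you capture the intermediate results (features) that one model produces during inference and feed them into a different model. This makes it very easy to combine models like building blocks.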
from typing import Dict, Iterable, Callable

import torch
from torch import nn, Tensor

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialize the names so a generator is not consumed twice
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            layer = modules[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    # A function that builds and returns the hook for one layer
    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # Store this layer's output under its name
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features
Usage:
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}
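The "building blocks" idea is to pass the captured features on to another model. The sketch below illustrates this under some stated assumptions: a tiny convolutional `backbone` stands in for resnet50 so that only `torch` is needed, and `head` is a hypothetical downstream model that consumes the pooled features.

```python
import torch
from torch import nn

# Small stand-in backbone (hypothetical, used instead of resnet50)
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
)
features = {}

def save_pooled(module, inputs, output):
    # detach() keeps the stored tensor out of the backbone's autograd graph
    features["pool"] = output.detach()

backbone[2].register_forward_hook(save_pooled)

head = nn.Linear(8, 5)  # hypothetical downstream model fed by the captured features
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    backbone(x)  # running the backbone fills `features` through the hook
logits = head(features["pool"].flatten(1))
print(features["pool"].shape, logits.shape)
# torch.Size([4, 8, 1, 1]) torch.Size([4, 5])
```

Here the backbone acts as a frozen feature extractor and only `head` would be trained. If you also wanted gradients to flow back into the backbone, you would drop both `detach()` and `torch.no_grad()`.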
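2.3 Gradient clipping
Hooks can also clip gradients, which helps with exploding gradients. I have not run into this problem myself yet, so see the original article for details.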
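As a hedged sketch of the idea, rather than the original article's exact code: `Tensor.register_hook` attaches a hook to every parameter, and each hook clamps its gradient element-wise. The helper name `clip_gradients_with_hooks` and the clip value are illustrative. A tensor hook receives the gradient, and if it returns a tensor, that tensor replaces the gradient.

```python
import torch
from torch import nn

def clip_gradients_with_hooks(model: nn.Module, clip_value: float):
    """Clamp each parameter's gradient to [-clip_value, clip_value] (hypothetical helper)."""
    handles = []
    for param in model.parameters():
        # The returned tensor replaces the gradient before it is accumulated into .grad
        handles.append(param.register_hook(lambda grad: grad.clamp(-clip_value, clip_value)))
    return handles

model = nn.Linear(3, 1)
clip_gradients_with_hooks(model, clip_value=0.5)
loss = model(torch.full((2, 3), 100.0)).sum() * 1000  # deliberately huge gradients
loss.backward()
print(model.weight.grad.max().item(), model.bias.grad.item())  # 0.5 0.5
```

Without the hooks, every weight gradient here would be 200000. For comparison, `torch.nn.utils.clip_grad_value_` applies the same element-wise clamp after `backward()` has finished, whereas the hooks clamp each gradient as soon as autograd computes it.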
From: https://www.cnblogs.com/x-ocean/p/16861317.html