首页 > 其他分享 >PyTorch笔记:hook的作用

PyTorch笔记:hook的作用

时间:2022-11-05 21:24:29浏览次数:70  
标签:10 layer self torch 笔记 hook PyTorch Size

参考自https://zhuanlan.zhihu.com/p/279903361,原始来自:https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904

在Module官方文档那片笔记中已经有一部分关于hook的介绍了,但是这里的更为具体,更能让我体会到hook的作用

1. 什么是钩子hook

所谓钩子就是:特定事件之后自动执行的函数。类似于回调函数

2. 钩子作用

2.1 显示模型执行详情,方便调试

首先创建一个包装类

class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

利用这个包装类,我们输出一个模型内部的结构。

import torch
from torchvision.models import resnet50

verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)

_ = verbose_resnet(dummy_input)

输出:

# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])

2.2 特征提取

利用钩子,可以将另一个模型推理的中间结果(特征)输出,用于另一个别的模型使用。(非常适合搭积木!)

from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    # 这是一个产生产生函数的函数
    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features

使用时:

resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)

print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}

2.3 梯度裁剪

可以处理梯度爆炸,但是我还没接触过这些问题,可以参考原文

标签:10,layer,self,torch,笔记,hook,PyTorch,Size
From: https://www.cnblogs.com/x-ocean/p/16861317.html

相关文章

  • Extjs复习笔记(十七)-- 给grid里面的内容分组
    From: https://www.likecs.com/show-203524189.html 给grid里面的内容分组。 Ext.onReady(function(){Ext.QuickTips.init();//开启浮动汽泡提示功能var......
  • EXTJS学习笔记:grid之分组实现groupingview
    使用extjs开发时常会用到grid来显示数据等操作,Extjs中Grid主要分为以下二类:  一、gridview   二、groupingview   gridview在前面已说过,在这里我来说说groupin......
  • Linux学习笔记之常用命令——文件的基础操作篇
    stat查看inodels显示文件列表ls-a显示所有文件(包括隐藏文件)ll按照行数显示文件列表,相当于ls-lcd切换到某个指定路径.表示当前路径cd..返回上一级目录cd-......
  • PyTorch笔记:如何保存与加载checkpoints
    https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html保存和加载checkpoints很有帮助。为了保存checkpoints,必须将它们放在......
  • 笔记01--《可解释的机器学习》
    书籍来源:https://christophm.github.io/interpretable-ml-book/bike-data.html线性回归的解释-4.1.7稀疏线性模型 解释线性回归的模型,若是遇到特征较多的情况,可采用不......
  • PyTorch笔记:Python中的state_dict是啥
    来自:https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html在PyTorch中,可学习的参数都被保存在模型的parameters中,可以通过model.parameters()访问......
  • SpringBoot实战笔记:02_使用注解与Java配置的Aop示例
    转载:https://blog.csdn.net/android_zyf/article/details/79579875<!--02_新的依赖--><!--导入spring的aop支持--><dependency><groupId>${spring-groupId}</groupId>......
  • SpringBoot实战笔记:01_Spring中的Java配置
    转载:https://blog.csdn.net/android_zyf/article/details/79579862Spring4.x与SpringBoot都推荐使用Java配置xml配置:将bean的信息配置在xml配置文件中注解配置:在对应的bea......
  • VOLO论文笔记
    OutlookAttention设给定输入为\(X\inR^{H\timesW\timesC}\),首先经过两个线性映射得到两个输出A和V,A叫做outlookweight\(A\inR^{H\timesW\timesK^4}\)......
  • VIT论文笔记
    VITAnimageisworth16x16words:transformersforimagerecognitionatscale将transformer首次应用在视觉任务中,并取得了超过CNN方法的性能。标准的transformer接......