首页 > 其他分享 >取出预训练模型中间层的输出(pytorch)

取出预训练模型中间层的输出(pytorch)

时间:2023-03-12 09:55:39浏览次数:46  
标签:layers 输出 return name nn hook pytorch 中间层 model

1 遍历子模块直接提取

对于简单的模型,可以采用直接遍历子模块的方法,取出相应name模块的输出,不对模型做任何改动。该方法的缺点在于,只能得到其子模块的输出,而对于使用nn.Sequensial()中包含很多层的模型,无法获得其指定层的输出

示例 resnet18取出layer1的输出

from torchvision.models import resnet18
import torch

model = resnet18(pretrained=True)
print("model:", model)
out = []
x = torch.randn(1, 3, 224, 224)
return_layer = "layer1"
for name, module in model.named_children():
    x = module(x)
    if name == return_layer:
        out.append(x.data)
        break
print(out[0].shape)  # torch.Size([1, 64, 56, 56])

2 IntermediateLayerGetter类

torchvison中提供了IntermediateLayerGetter类,该方法同样只能得到其子模块的输出,而对于使用nn.Sequensial()中包含很多层的模型,无法获得其指定层的输出

from torchvision.models._utils import IntermediateLayerGetter

IntermediateLayerGetter类的pytorch源码

class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}

        # 重新构建backbone,将没有使用到的模块全部删掉
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

示例 使用IntermediateLayerGetter类 改 resnet34+unet 完整代码见gitee

import torch
from torchvision.models import resnet18, vgg16_bn, resnet34
from torchvision.models._utils import IntermediateLayerGetter

model = resnet34()
stage_indices = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
return_layers = dict([(str(j), f"stage{i}") for i, j in enumerate(stage_indices)])
model= IntermediateLayerGetter(model, return_layers=return_layers)
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])

3 create_feature_extractor函数

使用create_feature_extractor方法,创建一个新的模块,该模块将给定模型中的中间节点作为字典返回,用户指定的键作为字符串,请求的输出作为值。该方法比 IntermediateLayerGetter方法更通用, 不局限于获得模型第一层子模块的输出。比如下面的vgg,池化层都在子模块feature中,上面的方法无法取出,因此推荐使用create_feature_extractor方法。

示例 FCN论文中以vgg为backbone,分别取出三个池化层的输出

import torch
from torchvision.models import vgg16_bn
from torchvision.models.feature_extraction import create_feature_extractor

model = vgg16_bn()
model = create_feature_extractor(model, {"features.43": "pool5", "features.33": "pool4", "features.23": "pool3"})
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])

4 hook函数

  hook函数是程序中预定义好的函数,这个函数处于原有程序流程当中(暴露一个钩子出来)。我们需要再在有流程中钩子定义的函数块中实现某个具体的细节,需要把我们的实现,挂接或者注册(register)到钩子里,使得hook函数对目标可用。hook 是一种编程机制,和具体的语言没有直接的关系。

  Pytorch的hook编程可以在不改变网络结构的基础上有效获取、改变模型中间变量以及梯度等信息。在pytorch中,Module对象有register_forward_hook(hook) 和 register_backward_hook(hook) 两种方法,两个的操作对象都是nn.Module类,如神经网络中的卷积层(nn.Conv2d),全连接层(nn.Linear),池化层(nn.MaxPool2d, nn.AvgPool2d),激活层(nn.ReLU)或者nn.Sequential定义的小模块等。register_forward_hook是获取前向传播的输出的,即特征图或激活值register_backward_hook是获取反向传播的输出的,即梯度值。(这边只讲register_forward_hook,其余见链接

示例 获取resnet18的avgpool层的输入输出

import torch
from torchvision.models import resnet18

model = resnet18()
fmap_block = dict()  # 装feature map
def forward_hook(module, input, output):
    fmap_block['input'] = input
    fmap_block['output'] = output

layer_name = 'avgpool'
for (name, module) in model.named_modules():
    if name == layer_name:
        module.register_forward_hook(hook=forward_hook)

input = torch.randn(64, 3, 224, 224)
output = model(input)
print(fmap_block['input'][0].shape)
print(fmap_block['output'].shape)

  

参考

1. Pytorch提取预训练模型特定中间层的输出

2. Pytorch的hook技术——获取预训练/已训练好模型的特定中间层输出

 

标签:layers,输出,return,name,nn,hook,pytorch,中间层,model
From: https://www.cnblogs.com/Fish0403/p/17141048.html

相关文章

  • 倒序输出升级版
    数字反转(升级版)题目描述给定一个数,请将该数各个位上数字反转得到一个新数。这次与NOIp2011普及组第一题不同的是:这个数可以是小数,分数,百分数,整数。整数反转是将所......
  • 遭遇奇怪的问题:所有 ASP.NET Core ViewComponent 都输出为空
    3月9日晚上的一次发布中遇到一个非常奇怪的问题,发布前在staging环境测试正常,发布到生产环境后发现所有ViewComponent都输出为空(没有任何内容)。生产环境与staging环......
  • pytorch MRI脑瘤检测
    读取数据#readingtheimagestumor=[]path='D:\\data\\Tumor_detection\\archive\\brain_tumor_dataset\\yes\\*.jpg'#*表示所有forfinglob.iglob(path):#遍历......
  • Jupyter Notebook 输出有颜色的文字
    JupyterNotebook支持HTML语言,所以在Markdowm模式下,直接输入HTML代码即可得到想要的。无论是图片还是文字大小颜色等。<fontcolor="red">红色的文字</font>下面是转来的......
  • HJ58 输入n个整数,输出其中最小的k个
    描述输入n个整数,找出其中最小的k个整数并按升序输出 本题有多组输入样例数据范围:1\len\le1000\1≤n≤1000  ,输入的整数满足 1\leval\le10000\1≤val≤......
  • 在JSP页面对一组数据进行判断之后输出
    需求:在合同信息管理页面做一个临期提醒功能,即在终止日期前一个月将该条记录显示在别处以提醒工作人员。实现思路:我在建表时多加了一项bool型数据,用来标记这项数据是否在一......
  • 电脑输入法仅桌面,登录系统后不能切换输入法,只能输出英文,不打出汉字
    情况描述:输入法在锁屏界面可以切换输入法,登录进去电脑后没办法切换输入法,而且没办法打出汉字。输入法显示仅桌面定位搜狗输入法路径我这里的搜狗输入法安装路径C:\ProgramF......
  • pytorch中GroupNorm的使用(笔记)
    不使用BatchNorm的原因  相比起batchNorm,groupNorm对一个batch的size更不敏感。在现有的研究中,batchNorm在batchsize越大的情况下,效果更佳,batchsize=32的情况明显优于ba......
  • vscode c语言printf打印中文,终端输出乱码问题
    具体步骤:文件->首选项->设置->下图图标进入settings.json在settings.json中添上"code-runner.executorMap":{"javascript":"node","java":"cd$dir&&javac$f......
  • 2023-3-8-pytorch关于CNN的描述与代码,ai对话生成的
    左脚踩右脚上天了是什么搞笑笑话这不是搞笑笑话,而是一个玩笑语言游戏。“左脚踩右脚上天了”实际上是一个不合逻辑的语句,因为脚无法踩在自己的脚上,更不可能踩上天。这种语......