首页 > 其他分享 >Conv层与BN层融合

Conv层与BN层融合

时间:2024-03-13 09:12:40浏览次数:16  
标签:Conv conv bn 融合 sqrt model BN gamma

Conv层与BN层融合

目录

简介

当前CNN卷积层的基本组成单元标配:Conv + BN +ReLU 三剑客,可以将BN层的运算融合到Conv层中,把三层减少为一层

减少运算量,加速推理。本质上是修改了卷积核的参数,在不增加Conv层计算量的同时,适用于模型推理。

BN(批归一化)层常用于在卷积层之后,对feature maps进行归一化,从而加速网络学习,也具有一定的正则化效果。训练时,BN需要学习一个minibatch数据的均值、方差,然后利用这些信息进行归一化。而在推理过程,通常为了加速,都会把BN融入到其上层卷积中,这样就将两步运算变成了一步,也就达到了加速目的。

要求

融合BN与卷积要求BN层位于卷积之后

且融合后的卷积层参数convolution_param中的bias_term必须为true。

原理

BN层参数

nn.Conv2d参数:

滤波器权重,			W:conv.weight
bias,				b:conv.bias
nn.BatchNorm2d参数:

scaling, γ:bn.weight
shift,  β:bn.bias
mean estimate,μ: bn.running_mean
variance estimate,σ^2 :bn.running_var
ϵ  (for numerical stability): bn.eps

BN层计算公式

pFDl1dU.png

在训练的时候,均值 \(\mu\) 、方差 $ \sigma^2$ 、 \(\gamma\) 、 \(\beta\) 是一直在更新的,在推理的时候,以上四个值都是固定了的,也就是推理的时候,均值和方差来自训练样本的数据分布。

因此,在推理的时候,上面BN的计算公式可以变形

\[{y}_{i}=\gamma \frac{x_{i}-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta=\frac{\gamma x_{i}}{\sqrt{\sigma^{2}+\epsilon}}+(\beta-\frac{\gamma \mu}{\sqrt{\sigma^{2}+\epsilon}}) \]

上面公式可以等价于

\[y_i=ax_i+b \]

\[a=\frac{\gamma}{\sqrt{\sigma^{2}+\epsilon}}\quad \quad b= (\beta-\frac{\gamma \mu}{\sqrt{\sigma^{2}+\epsilon}}) \]

$ \mu , \sigma^2 $ 为这个batch上计算得到的均值和方差(在B,H,W维度上计算,每个channel单独计算),而 \(\epsilon\) 是防止除零所设置的一个极小值, \(\gamma\) 是比例参数,而 \(\beta\)​​ 是平移系数。

此时,BN层转换成Conv层

Conv和BN计算合并

\[Y=\gamma \frac{(W*X+B)-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \]

合并后

\[Y=\frac{ \gamma *W}{\sqrt{\sigma^{2}+\epsilon}} *X + \frac{\gamma*(B-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta \]

\[W_{merged}=\frac{\gamma}{\sqrt{\sigma^{2}+\epsilon}}* W = W*a \quad \quad B_{merged}= \frac{\gamma*(B-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta = (B- \mu)*a+ \beta \]

1.首先我们将测试阶段的BN层(一般称为frozen BN)等效替换为一个1x1卷积层

2.将卷积层与归一化层融合

pytorch-BN融合

    import torch
    import torchvision
    
    def fuse(conv, bn):
    
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=True
        )
    
        # setting weights
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
        fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
        
        # setting bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros( conv.weight.size(0) )
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                              torch.sqrt(bn.running_var + bn.eps)
                            )
        fused.bias.copy_( b_conv + b_bn )
    
        return fused
    
    # Testing
    # we need to turn off gradient calculation because we didn't write it
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    resnet18 = torchvision.models.resnet18(pretrained=True)
    # removing all learning variables, etc
    resnet18.eval()
    model = torch.nn.Sequential(
        resnet18.conv1,
        resnet18.bn1
    )
    f1 = model.forward(x)
    fused = fuse(model[0], model[1])
    f2 = fused.forward(x)
    d = (f1 - f2).mean().item()
    print("error:",d)

ONNX-BN融合

import onnx
import os
from onnx import optimizer

# Preprocessing: load the model contains two transposes.
# model_path = os.path.join('resources', 'two_transposes.onnx')
# original_model = onnx.load(model_path)
original_model = onnx.load("resne18.onnx")
# Check that the IR is well formed
onnx.checker.check_model(original_model) 
print('The model before optimization:\n\n{}'.format(onnx.helper.printable_graph(original_model.graph)))


# A full list of supported optimization passes can be found using get_available_passes()
all_passes = optimizer.get_available_passes()
print("Available optimization passes:")
for p in all_passes:
    print('\t{}'.format(p))
print()

# Pick one pass as example
passes = ['fuse_add_bias_into_conv']

# Apply the optimization on the original serialized model
optimized_model = optimizer.optimize(original_model, passes)

print('The model after optimization:\n\n{}'.format(onnx.helper.printable_graph(optimized_model.graph)))

# save new model
onnx.save(optimized_model, "newResnet18.onnx")

参考资料

bn层学习笔记 卷积层和BN层融合

pytorch中BN层和卷积层的merge

https://www.cnblogs.com/nowgood/p/juan-ji-ceng-he-liang-hua-ceng-rong-he.html?ivk_sa=1024320u

Conv和BN算子融合(参数重构)

标签:Conv,conv,bn,融合,sqrt,model,BN,gamma
From: https://www.cnblogs.com/tian777/p/18069795

相关文章

  • 关于树莓派5(Ubnutu 23.10和树莓派5自带的系统通用)下载时出现error: externally-manage
    一.报错产生的原因  最近作者更新了这两个系统,在作者想去安装非 Debian的库的时候总是出现以下的报错:error:externally-managed-environment这是因为树莓派5升级了服务器系统,从Debian11到了Debian12,这个服务器系统对于外接库的限制还是比较严格的。作者也按照系......
  • 智慧城市数据大融合的几点想法
        随着信息化的不断深入,产生了各种类型的数据,包括结构化数据和非结构化数据,用不同的方式呈现出来,如数值型、文本型、图形图像、音频视频、传感器信号等格式。这些数据来源于现实世界,描述了现实世界,根据这些描述现实世界的数据,我们应该可以归纳出一定的社会规律,自然规......
  • 使用 PMML 实现模型融合及优化技巧
    在机器学习的生产环境中,我们经常需要将多个模型的预测结果进行融合,以便提高预测的准确性。这个过程通常涉及到多个模型子分的简单逻辑回归融合。虽然离线训练时我们可以直接使用sklearn的逻辑回归进行训练和调参,但在生产环境中,模型的上线往往需要使用PMML(PredictiveModelMarkup......
  • TSINGSEE青犀煤矿矿井视频监控与汇聚融合管理视频监管平台建设方案
    一、背景需求随着我国经济的飞速发展,煤炭作为我国的主要能源之一,其开采和利用的重要性不言而喻。然而,煤矿事故频发,不仅造成了巨大的人员伤亡和财产损失,也对社会产生了深远的负面影响。视频监控系统作为实现煤矿智能化无人开采的关键系统与煤矿安全生产的多系统协同分析预处理的关......
  • wpf datagrid row background color alternatively changed based on row index,Alter
    <Windowx:Class="WpfApp7.MainWindow"xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d="http://schemas.microsoft.c......
  • 视频监控/云存储EasyCVR视频融合平台设备增删改操作不生效是什么原因?
    国标GB28181协议EasyCVR安防平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力,平台支持7*24小时实时高清视频监控,能同时播放多路监控视频流,视频画面1、4、9、16个可选,支持自定义......
  • 无线表格识别模型LORE转换库:ConvertLOREToONNX
    引言总有小伙伴问到阿里的无线表格识别模型是如何转换为ONNX格式的。这个说来有些惭愧,现有的ONNX模型是很久之前转换的了,转换环境已经丢失,且没有做任何笔记。今天下定决心再次尝试转换,庆幸的是转换成功了。于是有了转换笔记:ConvertLOREToONNX。这次吸取教训,环境文件采用Anacond......
  • P6810 「MCOI-02」Convex Hull 凸包 题解
    分析推式子题。\[ans=\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}\tau(i)\tau(j)\tau(\gcd(i,j))\]对于\((i,j)\),若\(k\)是\((i,j)\)的因子,则\(k\)一定整除\(i,j\),所以有:\[\\\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}\tau(i)\tau(j)\sum\limits......
  • golang标准库之 flag、strconv
    目录一、flag库1.flag的简单替代2.flag的参数类型3.flag参数的定义(1)flag.Type()(2)flag.TypeVar()4.flag解析命令行参数5.flag其他方法二、strconv库1.string转换为int类型2.int转换为string类型3.Parse系列函数(1)ParseBool()(2)ParseInt()(3)ParseUnit()(4)ParseFloat()(5)示例4.Fo......
  • 对象不能从 DBNull 转换为其他类型,数据库空数据映射实体类的时候如何处理数据
    场景是这样的数据库有几个字段是可以为空的、即插入的时候可以不插这些数据,当一条有‘缺口’的数据回到后端映射实体类的时候,会导致对象不能从DBNull转换为其他类型的错误此时可以编写一个通用的方法来处理这种转换publicstaticTConvertDBNull<T>(objectvalue,Tdefaul......