首页 > 其他分享 >【YOLOv8改进 - 注意力机制】GC Block: 全局上下文块,高效捕获特征图中的全局依赖关系

【YOLOv8改进 - 注意力机制】GC Block: 全局上下文块,高效捕获特征图中的全局依赖关系

时间:2024-07-18 17:30:03浏览次数:20  
标签:conv nn self YOLOv8 C2f context GC 全局 channel

YOLOv8目标检测创新改进与实战案例专栏

专栏目录: YOLOv8有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLOv8基础解析+创新改进+实战案例

介绍

image-20240718164344690

摘要

非局部网络(NLNet)通过聚合特定查询位置的全局上下文,为捕捉长程依赖性提供了开创性的方法。然而,通过严格的实证分析,我们发现非局部网络在同一图像的不同查询位置所建模的全局上下文几乎相同。在本文中,我们利用这一发现,创建了一个基于与查询无关公式的简化网络,该网络在保持NLNet准确性的同时,大幅减少了计算量。我们进一步观察到,这种简化设计在结构上与挤压-激励网络(SENet)相似。因此,我们将它们统一到一个三步通用框架中,用于全局上下文建模。在这一通用框架内,我们设计了一个更好的实例,称为全局上下文(GC)块,它轻量化且能有效建模全局上下文。由于其轻量化特性,我们可以将其应用于骨干网络的多个层次,构建一个全局上下文网络(GCNet),该网络在各种识别任务的主要基准测试中普遍优于简化的NLNet和SENet。代码和配置发布在:https://github.com/xvjiarui/GCNet。

文章链接

论文地址:论文地址

代码地址:代码地址

参考代码代码地址

基本原理

GC Block 详细介绍

全局上下文块(Global Context Block, GC Block)是Global Context Network(GCNet)的核心组件,设计用来高效捕获特征图中的全局依赖关系。它结合了非局部网络(NLNet)和挤压-激励网络(SENet)的优势,具体结构如下:

1. 上下文建模模块(Context Modeling Module)

这个模块的主要目的是聚合所有位置的特征形成全局上下文特征,具体步骤如下:

  • 输入特征图:假设输入特征图为 X ∈ R C × H × W X \in \mathbb{R}^{C \times H \times W} X∈RC×H×W,其中 C C C 表示通道数, H H H 和 W W W 分别表示特征图的高度和宽度。
  • 空间维度压缩:通过全局平均池化操作将空间维度压缩为单个向量,得到全局上下文特征 z ∈ R C z \in \mathbb{R}^C z∈RC:
    z = 1 H × W ∑ i = 1 H ∑ j = 1 W X i j z = \frac{1}{H \times W} \sum_{i=1}^H \sum_{j=1}^W X_{ij} z=H×W1​i=1∑H​j=1∑W​Xij​
  • 注意力权重计算:使用一个全连接层或1x1卷积层,将全局上下文特征变换为注意力权重 W z ∈ R C W_z \in \mathbb{R}^{C} Wz​∈RC:
    W z = σ ( W 1 z ) W_z = \sigma(W_1 z) Wz​=σ(W1​z)
    其中, W 1 W_1 W1​ 是可学习的权重矩阵, σ \sigma σ 是激活函数(通常为softmax)。
2. 特征变换模块(Feature Transform Module)

这个模块用于捕获特征图中通道之间的依赖关系:

  • 瓶颈变换:使用两层1x1卷积和ReLU激活函数,进行瓶颈变换以减少计算复杂度:
    y = W 2 ( δ ( W 1 X ) ) y = W_2 (\delta(W_1 X)) y=W2​(δ(W1​X))
    其中, W 1 W_1 W1​ 和 W 2 W_2 W2​ 是可学习的权重矩阵, δ \delta δ 是ReLU激活函数。
3. 特征融合模块(Feature Fusion Module)

这个模块的目的是将全局上下文特征融合到每个查询位置的特征中:

  • 特征融合:通过加法操作,将全局上下文特征 z z z 融合到每个位置的特征 X i j X_{ij} Xij​ 中,得到增强的特征图 Y ∈ R C × H × W Y \in \mathbb{R}^{C \times H \times W} Y∈RC×H×W:
    Y i j = X i j + z Y_{ij} = X_{ij} + z Yij​=Xij​+z

GC Block 流程总结

  1. 输入特征图: X ∈ R C × H × W X \in \mathbb{R}^{C \times H \times W} X∈RC×H×W。
  2. 全局上下文建模:通过全局平均池化和全连接层计算全局上下文特征 z z z。
  3. 特征变换:使用瓶颈变换模块对特征图进行变换。
  4. 特征融合:将全局上下文特征 z z z 融合到每个位置的特征 X i j X_{ij} Xij​ 中。

总结

GC Block 通过上下文建模、特征变换和特征融合三个模块,高效地捕获并利用图像中的全局上下文信息。这种设计不仅显著提高了模型在各种视觉识别任务中的性能,还保持了较低的计算成本和内存消耗。因此,GC Block 在实际应用中具有很高的实用价值。

核心代码

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nn


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


class ContextBlock(nn.Module):

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out

下载YoloV8代码

直接下载

GitHub地址

image-20240116225427653

Git Clone

git clone https://github.com/ultralytics/ultralytics

安装环境

进入代码根目录并安装依赖。

image-20240116230741813

image-20240116230741813

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

在最新版本中,官方已经废弃了requirements.txt文件,转而将所有必要的代码和依赖整合进了ultralytics包中。因此,用户只需安装这个单一的ultralytics库,就能获得所需的全部功能和环境依赖。

pip install ultralytics

引入代码

在根目录下的ultralytics/nn/目录,新建一个 attention目录,然后新建一个以 GCBlock为文件名的py文件, 把代码拷贝进去。

以下代码来源于论文:

(通过这个论文看,yolov8+A+B+特定场景是真的好发论文!!!!)

https://github.com/RuiyangJu/YOLOv8_Global_Context_Fracture_Detection

不同版本的timm 导包路径可能不同!

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers.create_act import create_act_layer, get_act_layer
from timm.models.layers import make_divisible
from timm.models.layers.mlp import ConvMlp
from timm.models.layers.norm import LayerNorm2d


class GlobalContext(nn.Module):

    def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
                 rd_ratio=1. / 8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super(GlobalContext, self).__init__()
        act_layer = get_act_layer(act_layer)

        self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None

        if rd_channels is None:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        if fuse_add:
            self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_add = None
        if fuse_scale:
            self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_scale = None

        self.gate = create_act_layer(gate_layer)
        self.init_last_zero = init_last_zero
        self.reset_parameters()

    def reset_parameters(self):
        if self.conv_attn is not None:
            nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
        if self.mlp_add is not None:
            nn.init.zeros_(self.mlp_add.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape

        if self.conv_attn is not None:
            attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
            attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
            context = context.view(B, C, 1, 1)
        else:
            context = x.mean(dim=(2, 3), keepdim=True)

        if self.mlp_scale is not None:
            mlp_x = self.mlp_scale(context)
            x = x * self.gate(mlp_x)
        if self.mlp_add is not None:
            mlp_x = self.mlp_add(context)
            x = x + mlp_x

        return x

注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:

 from ultralytics.nn.attention.GCBlock import GlobalContext

步骤2

修改def parse_model(d, ch, verbose=True):

        if m in (
            Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
            BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d,DWConvTranspose2d, C3x, RepC3, EVCBlock,
            CAFMAttention, CloFormerAttnConv, C2f_iAFF, CSPStage, nn.Conv2d, GSConv, VoVGSCSP, C2f_CBAM, C2f_RefConv,C2f_SimAM,
            DualConv, C2f_NAM, MSFE, C2f_MDCR,C2f_MHSA, PPA, C2f_DASI, C2f_CascadedGroupAttention, C2f_DWR, C2f_DWRSeg,
            SPConv, AKConv, C2f_DySnakeConv, ScConv, C2f_ScConv, RFAConv, RFCBAMConv, RFCAConv, CAConv, CBAMConv, C2f_MultiDilatelocalAttention,
            C2f_iRMB, C2f_MSBlock, MobileViTBlock,CoordAtt, MCALayer,C2f_FocusedLinearAttention, RCSOSA,RepNCSPELAN4, SPPELAN, S2Attention,
            C2f_DoubleAttention, Down_wt, GlobalContext
        ):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3, C2f_deformable_LKA,Sea_AttentionBlock, C2f_iAFF,CSPStage, VoVGSCSP,C2f_SimAM,
                    C2f_NAM, C2f_MDCR, C2f_DASI, C2f_DWR, C2f_DWRSeg , C2f_MultiDilatelocalAttention, C2f_iRMB, C2f_MSBlock, MobileViTBlock,C2f_FocusedLinearAttention
                    ):
                args.insert(2, n)  # number of repeats
                n = 1

image-20240718170846878

配置yolov8_GCBlock.yaml

ultralytics/ultralytics/cfg/models/v8/yolov8_GCBlock.yaml

# Ultralytics YOLO 

标签:conv,nn,self,YOLOv8,C2f,context,GC,全局,channel
From: https://blog.csdn.net/shangyanaf/article/details/140528152

相关文章

  • 【YOLOv8改进-SPPF】 AIFI : 基于注意力的尺度内特征交互,保持高准确度的同时减少计算
    YOLOv8目标检测创新改进与实战案例专栏专栏目录:YOLOv8有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLOv8基础解析+创新改进+实战案例介绍摘要YOLO系列因其在速度和准确性之间的合理权衡,成为了......
  • 【YOLOv8改进 - 特征融合NECK】SDI:多层次特征融合模块,替换contact操作
    YOLOv8目标检测创新改进与实战案例专栏专栏目录:YOLOv8有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLOv8基础解析+创新改进+实战案例介绍摘要在本文中,我们介绍了U-Netv2,一种用于医学图像分割......
  • VS快速全局查找Unity死循环代码
    1、编写一个死循环方法,然后运行调试vsusingUnityEngine;publicclassDeadLoop:MonoBehaviour{//StartiscalledbeforethefirstframeupdatevoidStart(){DeadLoopMethod();}voidDeadLoopMethod(){while(t......
  • SpringCloud项目的搭建
    一、相关依赖导入1.注册中心<dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-netflix-eureka-server</artifactId><version>4.1.0</version>......
  • 结合LangChain实现网页数据爬取
    LangChain非常强大的一点就是封装了非常多强大的工具可以直接使用。降低了使用者的学习成本。比如数据网页爬取。在其官方文档-网页爬取中,也有非常好的示例。应用场景信息爬取。RAG信息检索。实践应用需求说明从ceshiren网站中获取每个帖子的名称以及其对应的url信......
  • LangChain补充二:LCEL和Runnable更加方便的创建调用链
    https://www.alang.ai/langchain/101/lc05一:LCEL入门LangChain的设计围绕着让AI应用开发者能够方便地将多个流程连缀成一个AI应用的业务逻辑,包括Chain与Agent。每个流程都被封装成一个runnable(langchain_core.runnables),包括提示语模板、模型调用、输出解析器、工具......
  • 扩展欧几里得算法(exGcd)
    扩展欧几里得算法(ExtendedEuclideanalgorithm,EXGCD),常用于求\(ax+by=c\)的一组可行解。过程设\(ax_1+by_1=\gcd(a,b)\)\(bx_2+(a\modb)y_2=gcd(b,a\modb)\)由欧几里得算法:\(\gcd(a,b)=gcd(b,a\modb)\)所以:\(ax_1+by_1=bx_2+(a\modb)y_2\)又因为:\(a\mod......
  • LangChain补充一:一些小且有用的点
    一:LangChain表达式语言LCEL(LangChainExpressionLanguage)chain:我们可以将包括大模型调用在内的一组操作组成“链条”,即所谓“调用链”(一)概念LangChain提供的LangChainExpressionLanguage(LCEL)让开发可以很方便地将多个组件连接成AI工作流(或者说是调用链)。如下是一......
  • LangChain补充五:Agent之LangGraph的使用
    一:LangGraph入门https://www.51cto.com/article/781996.htmlhttps://blog.csdn.net/weixin_41496173/article/details/139023846https://blog.csdn.net/wjjc1017/article/details/138518087https://langchain-ai.github.io/langgraph/https://langchain-ai.github.io/langg......
  • LangChain补充四:Agent知识点和案例补充
    https://www.alang.ai/langchain/101/lc07一:基本流程和概念(一)概念LangChainAgent的核心思想是,使用大语言模型选择一系列要执行的动作。在Chain中,一系列动作是硬编码在代码中的。在Agent中,大语言模型被用作推理引擎,以确定要采取的动作及其顺序。它包括3个组件:规划:将任......