首页 > 其他分享 >YOLOv8改进 - 注意力篇 - 引入(A2-Nets)Double Attention Networks注意力机制

YOLOv8改进 - 注意力篇 - 引入(A2-Nets)Double Attention Networks注意力机制

时间:2024-09-29 12:47:31浏览次数:19  
标签:DoubleAttention nn Double self Networks init A2 Nets 注意力

一、本文介绍

作为入门性篇章,这里介绍了A2-Nets网络注意力在YOLOv8中的使用。包含A2-Nets原理分析,A2-Nets的代码、A2-Nets的使用方法、以及添加以后的yaml文件及运行记录。

二、A2-Nets原理分析

A2-Nets官方论文地址:A2-Nets文章

A2-Nets注意力机制(双重注意力机制):它从输入图像/视频的整个时空空间中聚集和传播信息全局特征,使后续卷积层能够有效地从整个空间中访问特征。采用双注意机制(包括Spatial Attention和Channel Attention。Spatial Attention用于捕获图像中不同空间位置的重要性,而Channel Attention用于捕获图像中不同通道的重要性),分两步进行设计,第一步通过二阶注意池将整个空间的特征聚集成一个紧凑的集合,第二步通过另一个注意自适应地选择特征并将其分配到每个位置。​

相关代码:

A2-Nets注意力的代码,如下。

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class DoubleAttention(nn.Module):

    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct

        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        assert c == self.in_channels
        A = self.convA(x)  # b,c_m,h,w
        B = self.convB(x)  # b,c_n,h,w
        V = self.convV(x)  # b,c_n,h,w
        tmpA = A.view(b, self.c_m, -1)
        attention_maps = F.softmax(B.view(b, self.c_n, -1))
        attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=-1)
        # step 1: feature gating
        global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1))  # b.c_m,c_n
        # step 2: feature distribution
        tmpZ = global_descriptors.matmul(attention_vectors)  # b,c_m,h*w
        tmpZ = tmpZ.view(b, self.c_m, h, w)  # b,c_m,h,w
        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)

        return tmpZ

四、YOLOv8中DoubleAttention使用方法

1.YOLOv8中添加DoubleAttention模块:

首先在ultralytics/nn/modules/conv.py最后添加DoubleAttention模块的代码。

2.在conv.py的开头__all__ = 内添加DoubleAttention模块的类别名(A2-Nets的类别名在本文中为DoubleAttention)

3.在同级文件夹下的__init__.py内添加A2-Nets的相关内容:(分别是from .conv import DoubleAttention ;以及在__all__内添加DoubleAttention)

4.在ultralytics/nn/tasks.py进行SK注意力机制的注册,以及在YOLOv8的yaml配置文件中添加DoubleAttention即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:(本文续接上篇文章,加在了CBAM、ECA、SKAttention的位置)

        elif m in {CBAM,ECA,SKAttention,DoubleAttention}:#添加注意力模块,没有CBAM、ECA、SKAttention的,将CBAM、ECA、SKAttention删除即可
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_DoubleAttention.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_DoubleAttention.yaml)

# Ultralytics YOLO 

标签:DoubleAttention,nn,Double,self,Networks,init,A2,Nets,注意力
From: https://blog.csdn.net/2301_79619145/article/details/142601192

相关文章

  • 【GAN】生成对抗网络Generative Adversarial Networks理解摘要
    【Pytorch】生成对抗网络实战_pytorch生成对抗网络-CSDN博客【损失函数】KL散度与交叉熵理解-CSDN博客  [1406.2661]GenerativeAdversarialNetworks(arxiv.org)GAN本质是对抗或者说竞争,通过生成器和鉴别器的竞争获取有效地结果,换句话说,GAN是在养蛊,大量数据和批次的......
  • 要求实现一个函数 DoubleToStr(double a,int b,char * str),将参数 a 转化为字符串 str
    sprintf函数:sprintf(str,"%.*f",b,a);:sprintf是一个格式化输出函数,类似于printf,但它将输出写入到字符串中而不是标准输出。"%.*f":#include<stdio.h>//将双精度浮点数a转换为字符串str,小数点后保留b位voidDoubleToStr(doublea,intb,char*str){  //......
  • YOLOv8改进 - 注意力篇 - 引入SK网络注意力机制
    一、本文介绍作为入门性篇章,这里介绍了SK网络注意力在YOLOv8中的使用。包含SK原理分析,SK的代码、SK的使用方法、以及添加以后的yaml文件及运行记录。二、SK原理分析SK官方论文地址:SK注意力文章SK注意力机制:SK网络中的神经元可以捕获具有不同比例的目标对象,实验验证了神经......
  • EfficientViT(2023CVPR):具有级联组注意力的内存高效视觉Transformer!
    EfficientViT:MemoryEfficientVisionTransformerwithCascadedGroupAttentionEfficientViT:具有级联组注意力的内存高效视觉Transformer万文长字,请耐心观看~论文地址:https://arxiv.org/abs/2305.07027代码地址:Cream/EfficientViTatmain·microsoft/Cream......
  • CAS-ViT:用于高效移动应用的卷积加法自注意力视觉Transformer
    近年来,VisionTransformer(ViT)在计算机视觉领域取得了巨大突破。然而ViT模型通常计算复杂度高,难以在资源受限的移动设备上部署。为了解决这个问题,研究人员提出了ConvolutionalAdditiveSelf-attentionVisionTransformers(CAS-ViT),这是一种轻量级的ViT变体,旨在在效率和性......
  • COMP3331/9331 Computer Networks and Applications
    COMP3331/9331ComputerNetworksandApplicationsAssignmentforTerm3,2024BitTrickleFileSharing System1. Goal and Learning ObjectivesIn this assignment you will have the opportunity to implement BitTrickle, apermissioned,peer-to- pee......
  • 实型(浮点型):float、double
    实型(浮点型):float、double实型变量也可以称为浮点型,浮点型变量是用来存储小数数值的。在C语言中,浮点型分为两种:单精度浮点型(float)、双精度浮点型(double),但是double型变量所表示的浮点数比float型变量更精确。由于浮点型变量是由有限的存储单元组成,因......