首页 > 其他分享 >【YOLOv8改进】BRA(bi-level routing attention ):双层路由注意力(论文笔记+引入代码)

【YOLOv8改进】BRA(bi-level routing attention ):双层路由注意力(论文笔记+引入代码)

时间:2024-06-22 09:20:29浏览次数:10  
标签:dim level downsample win self attention bi routing kv

摘要

作为视觉Transformers的核心构建模块,注意力机制是一种强大的工具,用于捕捉长程依赖关系。然而,这种强大功能也带来了代价:计算代价巨大且内存占用高,因为需要计算所有空间位置上成对的token交互。为缓解这一问题,一系列研究尝试通过引入手工设计且内容无关的稀疏性来改进注意力机制,例如将注意力操作限制在局部窗口、轴向条带或膨胀窗口内。与这些方法不同,我们提出了一种新颖的动态稀疏注意力机制,通过双层路由实现更加灵活且具有内容感知的计算分配。具体而言,对于一个查询,首先在粗略的区域级别过滤掉无关的键值对,然后在剩余候选区域(即路由区域)的联合中应用细粒度的token-to-token注意力。我们提供了一个简单而有效的双层路由注意力的实现,该实现利用稀疏性来节省计算和内存,同时仅涉及GPU友好的稠密矩阵乘法。基于所提出的双层路由注意力,我们提出了一种新的通用视觉Transformer,命名为BiFormer。BiFormer在查询自适应的方式下关注一小部分相关token,而不受其他无关token的干扰,因而在密集预测任务中享有良好的性能和高计算效率。在图像分类、目标检测和语义分割等多个计算机视觉任务中的实验证明了我们设计的有效性。代码可在https://github.com/rayleizhu/BiFormer获得。

摘要

摘要——高光谱图像(HSI)去噪对于高光谱数据的有效分析和解释至关重要。然而,同时建模全局和局部特征以增强HSI去噪的研究却很少。在本文中,我们提出了一种混合卷积和注意力网络(HCANet),该网络结合了卷积神经网络(CNN)和Transformers的优势。为了增强全局和局部特征的建模,我们设计了一个卷积和注意力融合模块,旨在捕捉长距离依赖关系和邻域光谱相关性。此外,为了改进多尺度信息聚合,我们设计了一个多尺度前馈网络,通过在不同尺度上提取特征来增强去噪性能。在主流HSI数据集上的实验结果表明,所提出的HCANet具有合理性和有效性。所提出的模型在去除各种复杂噪声方面表现出色。我们的代码可在https://github.com/summitgao/HCANet获得。

文章链接

论文地址:论文地址

代码地址:代码地址

参考代码:代码地址

基本原理

Bi-Level Routing Attention (BRA)是一种注意力机制,旨在解决多头自注意力机制(MHSA)的可扩展性问题。传统的注意力机制要求每个查询都要关注所有的键-值对,这在处理大规模数据时可能会导致计算和存储资源的浪费。BRA通过引入动态的、查询感知的稀疏注意力机制来解决这一问题。

BRA的关键思想是在粗粒度的区域级别上过滤出大部分不相关的键-值对,只保留少量的路由区域。然后,在这些路由区域的并集上应用细粒度的令牌-令牌注意力。这种方法使得每个查询只需关注少量相关的键-值对,从而提高了计算效率和内存利用率。

具体来说,BRA的实现包括以下步骤:

  1. 构建和修剪区域级别的有向图,以过滤出大部分不相关的键-值对。
  2. 在路由区域的并集上应用细粒度的令牌-令牌注意力,以实现动态的、查询感知的稀疏性。

yolov8 代码引入

 class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: wether to set routing differentiable
    soft_routing: wether to multiply soft routing weights 
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5


        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)
        
        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always diffrentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')
        
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v 
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139307690

标签:dim,level,downsample,win,self,attention,bi,routing,kv
From: https://www.cnblogs.com/banxia-frontend/p/18261865

相关文章

  • blender4.1添加骨骼复制位置和复制旋转约束代码(Armature-Biped_Root)
    添加旋转旋转约束importbpy#定义骨架中骨骼的映射关系bone_mapping={"mixamorig:Hips":"Pelvis","mixamorig:LeftUpLeg":"Left_Thigh","mixamorig:LeftLeg":"Left_Calf","mixamorig:LeftFoot&q......
  • [转] MySQL binlog 日志自动清理及手动删除
    参考转载自mysqlbinlog日志自动清理及手动删除-景岳-博客园说明当开启mysql数据库主从时,会产生大量如mysql-bin.00000*log的文件,这会大量耗费您的硬盘空间。mysql-bin.000001mysql-bin.000002mysql-bin.000003mysql-bin.000004mysql-bin.000005…有三种解......
  • 华为电脑BIOS设置系统启动顺序
        最近在华为电脑上装了Windows和Ubuntu双系统后,由于安装失误,导致每次开机后都会进入grub界面。    为了正常进入Windows和Ubuntu系统,在开机进入grub界面前,可以按F12进入bootmanager界面,在此界面下可以选择需要启动的系统。(请原谅我使用手机拍摄屏幕的方......
  • MySQL bit类型增加索引后查询结果不正确案例浅析
    昨天同事遇到的一个案例,这里简单描述一下:一个表里面有一个bit类型的字段,同事在优化相关SQL的过程中,给这个表的bit类型的字段新增了一个索引,然后测试验证时,居然发现SQL语句执行结果跟不加索引不一样。加了索引后,SQL语句没有查询出一条记录,删除索引后,SQL语句就能查询出几十条记录。......
  • zabbix agent 日志文件轮询分析
    1、zabbixagent日志文件轮询分析的初衷zabbixagent的日志文件默认在/var/log/zabbix目录下面。默认/目录只有20G或者40G,随着运行时间越来越长日志文件也会变大,会占用磁盘空间 2、zabbixagent文件为什么会过大是由于加了一些自定义监控项,这些监控项在执行的时候会记录......
  • c#中path.combine的用法是什么
    原文链接:https://www.yisu.com/ask/29579392.html在C#中,Path.Combine()方法用于将两个或多个字符串路径组合成一个有效的路径。它接受多个字符串参数作为路径的组成部分,并返回一个字符串,表示有效的路径。语法如下:publicstaticstringCombine(paramsstring[]paths);参数pa......
  • OCS2_mobile_manipulator案例详解
    1.启动共启动3个节点mobile_manipulator_mpc_node//mpc问题构建,计算mobile_manipulator_dummy_mrt_node//仿真,承接MPC的输出,发布Observation,对于仿真来讲,状态发布也是反馈mobile_manipulator_target//交互发布target2.MobileManipulatorMpcNode.cppMobileManipula......
  • epub与mobi可以相互转换吗?
    在数字化时代,电子书格式多样,每种格式都有其独特的特点和适用场景。其中,EPUB和MOBI是两种非常流行的电子书格式。然而,有时候,用户可能会因为某种需求或限制,希望将EPUB格式的文件转换为MOBI格式。或者反过来讲MOBI格式转换为EPUB格式?那么epub与mobi如何实现相互转换呢?方法一:在线转换......
  • 结合zabbix监控mysql,让mysql性能飙升
      前段时间客户的系统突然出现mysql只读集群cpu飙升的情况,飙升到最高点的时候,甚至导致应用服务器GC,幸好应用有备份服务器,流量直接切过去,客户也无感知。但是这个只是临时的解决办法,总归要找到具体的原因,和开发同事查了两天的应用日志和mysql的慢日志,始终无法定位到具体的问题。......
  • rabbitMQ实战生产者-交换机-队列-消费者细谈
     生产者rabbitmq的配置创建交换机,创建queue,绑定交换机的routingkey到queue一,默认的exchange列表 二,将exchange的routingkey绑定到queue 三,生产端关心消息将发放哪个交换机,哪个routingkey,也可以用通配符(如calc.*,calc.#)匹配相应的routingkey mq服务匹配exchange,rout......