首页 > 其他分享 >【Block总结】CrossFormerBlock

【Block总结】CrossFormerBlock

时间:2025-01-08 09:33:38浏览次数:10  
标签:总结 dim group nn CrossFormerBlock pos self Block size

论文介绍

链接:https://arxiv.org/pdf/2108.00154

  • CrossFormerBlock模块提出:论文提出了一种名为CrossFormer的视觉Transformer模型,其中重点介绍了CrossFormerBlock模块的设计。
  • 研究背景:针对视觉任务中自注意力模块计算成本高、难以处理跨尺度交互的问题,CrossFormerBlock模块进行了针对性的优化。
  • 目的:旨在通过改进自注意力模块,提高视觉Transformer模型的效率和性能。

创新点

  • 跨尺度嵌入层(CEL):引入了金字塔结构,将视觉Transformer模型分为多个阶段,每个阶段开始时使用CEL来处理不同尺度的嵌入。
  • 长短距离注意力(LSDA):将自注意力模块拆分为短距离注意力(SDA)和长距离注意力(LDA),以降低计算成本并保持跨尺度交互。
  • 动态位置偏置(DPB):提出了一个基于MLP的模块,用于动态生成相对位置偏置,增强了模型对位置信息的处理能力。
    在这里插入图片描述

方法

  • CEL实现:在每个阶段的开始,使用CEL将输入图像分割成不同尺度的块,并生成相应的嵌入。
  • LSDA实现:SDA通过分组相邻嵌入来计算依赖关系,LDA则通过采样具有固定间隔的嵌入来计算远程依赖关系。两者都使用标准的自注意力机制。
  • DPB实现:DPB接收两个嵌入的相对距离作为输入,并通过三个全连接层和非线性激活函数生成相对位置偏置。

模块作用

  • CEL作用:提供跨尺度特征,使模型能够更好地处理不同尺度的输入图像。
  • LSDA作用:降低自注意力模块的计算成本,同时保持对跨尺度交互的建模能力。
  • DPB作用:为模型提供动态的位置信息,增强了对图像中物体位置关系的理解。

改进的效果

  • 计算成本降低:通过LSDA,CrossFormerBlock模块显著降低了自注意力模块的计算成本。
  • 性能提升:在图像分类、目标检测和实例分割等任务上,CrossFormer模型表现出色,特别是在密集预测任务上,如检测和分割,相较于其他模型具有显著优势。
  • 跨尺度交互增强:CrossFormerBlock模块通过CEL和LSDA的结合,有效地增强了模型对跨尺度交互的建模能力,从而提高了在复杂视觉任务上的性能。

该论文提出的CrossFormerBlock模块通过引入跨尺度嵌入层、长短距离注意力和动态位置偏置等创新点,显著降低了计算成本并提升了视觉Transformer模型的性能。在多个视觉任务上,CrossFormer模型都表现出了出色的性能,特别是在密集预测任务上。

代码

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


# 论文地址:https://arxiv.org/pdf/2108.00154
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DynamicPosBias(nn.Module):
    def __init__(self, dim, num_heads, residual):
        super().__init__()
        self.residual = residual
        self.num_heads = num_heads
        self.pos_dim = dim // 4
        self.pos_proj = nn.Linear(2, self.pos_dim)
        self.pos1 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim),
        )
        self.pos2 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim)
        )
        self.pos3 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.num_heads)
        )

    def forward(self, biases):
        if self.residual:
            pos = self.pos_proj(biases)  # 2Wh-1 * 2Ww-1, heads
            pos = pos + self.pos1(pos)
            pos = pos + self.pos2(pos)
            pos = self.pos3(pos)
        else:
            pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
        return pos

    def flops(self, N):
        flops = N * 2 * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.num_heads
        return flops


class Attention(nn.Module):
    r""" Multi-head self attention module with dynamic position bias.

    Args:
        dim (int): Number of input channels.
        group_size (tuple[int]): The height and width of the group.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 position_bias=True):

        super().__init__()
        self.dim = dim
        self.group_size = group_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.position_bias = position_bias

        if position_bias:
            self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)

            # generate mother-set
            position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
            position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
            biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2W2-1
            biases = biases.flatten(1).transpose(0, 1).float()
            self.register_buffer("biases", biases)

            # get pair-wise relative position index for each token inside the group
            coords_h = torch.arange(self.group_size[0])
            coords_w = torch.arange(self.group_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.group_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_groups*B, N, C)
            mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.position_bias:
            pos = self.pos(self.biases)  # 2Wh-1 * 2Ww-1, heads
            # select position bias
            relative_position_bias = pos[self.relative_position_index.view(-1)].view(
                self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 group with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        if self.position_bias:
            flops += self.pos.flops(N)
        return flops


class CrossFormerBlock(nn.Module):
    r""" CrossFormer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        group_size (int): Group size.
        lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, group_size=10, lsda_flag=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.group_size = group_size
        self.lsda_flag = lsda_flag
        self.mlp_ratio = mlp_ratio
        self.num_patch_size = num_patch_size
        if min(self.input_resolution) <= self.group_size:
            # if group size is larger than input resolution, we don't partition groups
            self.lsda_flag = 0
            self.group_size = min(self.input_resolution)

        self.norm1 = norm_layer(dim)

        self.attn = Attention(
            dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            position_bias=True)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # group embeddings
        G = self.group_size
        if self.lsda_flag == 0:  # 0 for SDA
            x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
        else:  # 1 for LDA
            x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(B * H * W // G ** 2, G ** 2, C)

        # multi-head self-attention
        x = self.attn(x, mask=self.attn_mask)  # nW*B, G*G, C

        # ungroup embeddings
        x = x.reshape(B, H // G, W // G, G, G, C)
        if self.lsda_flag == 0:
            x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
        else:
            x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # LSDA
        nW = H * W / self.group_size / self.group_size
        flops += nW * self.attn.flops(self.group_size * self.group_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

if __name__ == '__main__':
    # 创建一个随机输入张量,形状为 (batch_size,height×width,channels)
    input = torch.rand(1,40*32,64)

    # 实例化CrossFormerBlock模块
    block = CrossFormerBlock(dim=64,input_resolution=(40,32),num_heads=8,group_size=1)
    # 前向传播
    output = block(input)

    # 打印输入和输出的形状
    print(input.size())
    print(output.size())

标签:总结,dim,group,nn,CrossFormerBlock,pos,self,Block,size
From: https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/144983414

相关文章

  • 『杂题总结』Day11 略解
    前言只闻花香,不谈悲喜。饮茶颂书,不争朝夕。对BZ的题目彻底失望了,开始自己瞎搞了。1.CF2057E2标签:\(\textbf{Floyd}\)。首先先考虑朴素做法。考虑每次询问二分答案,边权比\(\text{mid}\)小的边当作\(0\),否则当作\(1\)。如果\(a\tob\)的最短路\(\lek\),那么就是合......
  • 一文总结PCB检验标准有哪些,工程师必看!
    PCB检验标准可以分为两大类:国际标准和国家标准。国际标准是由国际组织制定的,适用于不同国家和地区的PCB生产和使用。比如,IPC(国际印刷电路协会)制定了一系列的IPC标准,包括IPC-A-600(印刷电路板外观验收标准)、IPC-6012(刚性印刷电路板性能规范)等。国家标准是由各个国家根据自......
  • AI算法专家总结的超实用Prompt 技巧,让AI更好更高效地为你服务
    欢迎来到AI应用探索,这里专注于探索AI应用。Prompt是人与AI沟通的桥梁,直接决定了AI生成内容的相关性和准确性。一个好的Prompt能帮助AI准确理解需求、生成更精准和实用的结果,同时节省时间和精力,提升工作效率。不论是内容创作、问题解答还是数据分析,一个好的Prompt都是激发AI潜......
  • 2025年1月7日的AI新闻总结
    以下是2025年1月7日的AI新闻总结:OpenAI动态年终总结:OpenAI首席执行官萨姆・奥特曼发表年终总结,回顾了公司从创立到ChatGPT问世的过程,以及自己被解雇后重返公司的经历,展望2025年AI发展,认为会有第一批人工智能智能体“加入劳动力大军”,OpenAI也将朝着超级智能目标前进.政治压......
  • 2024 年度总结
    2024年度总结提示:在本文中你将看到我所拥有的某些负面情感,或许也有可以正向反馈的东西,但是以下的特质体现的淋漓尽致:趋利避害的天性天津人骨子里拥有的沮丧对于困难的畏惧对于未来的沮丧对于迷茫生活的无助缺少爱的能力对于周围环境的批评,不在自己身上找问题......
  • react项目性能优化实践经验总结
    1.代码片段执行时间console.time('xxx')//代码片段console.timeEnd('xxx')在代码片段包裹上述代码,执行后,命令行会输出该段代码的执行时间,非常方便。2.reactProfilereact的<Profiler/>包裹组件后,并传入id和onRender回调函数。id是一个唯一标识符,用于区分不同的Profiler......
  • 总结并拆解所有新手常用的——String API(一)
    前言:String类包括的方法可用于检查序列的单个字符、比较字符串、搜索字符串、提取子字符串、创建字符串副本并将所有字符全部转换为大写或小写.......小编这次就比较全面系统的带大家总结清楚几乎所有string常用的API,并且带大家拆解清楚,能够灵活使用!!!小编今天总算是回家......
  • 北大营题目总结
    回文路径考虑到一个事情,对于\(s\)来说,我去二分回文半径长度,我往右拓展时肯定时\(s\)不能拓展的时候才会选择向下从\(t\)的对应位置向右拓展,按照这样,我的匹配策略是唯一的,接下来就变成一个模拟题了。然后分一下奇偶回文串然后看一下回文中心在\(s\)还是\(t\)还是交界处......
  • noip2024比赛总结
    信息学竞赛对学生综合能力的要求较高,例如数学、逻辑思维、思考速度、思考全面性等各个方面,同时,其考察学生耐性、专注度、严谨性、刻苦程度等性格有关方面,是一门综合性强,难度高,学习过程坎坷曲折的学科竞赛。在考完联赛noip之后,我想分享一下自己的比赛感受与学习经验,希望能对同学们......
  • 【渗透测试术语总结】
    目录题记渗透测试常用专业术语 加更:暗网转大佬笔记一、攻击篇二、防守篇Top渗透测试常用专业术语     相信大家和我一样,搞不清这些专业名词的区别,所以我来整理一下。1.POC、EXP、Payload与ShellcodePOC:全称'ProofofConcept',中文'概念验证',常指一......