
SAM2 Code Implementation Explained: Image Encoder (FpnNeck)


The top-level YAML config, OmegaConf, and Hydra

The official SAM2 implementation uses a YAML file to configure the overall model structure and its parameters. The key code is as follows:

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)  # checkpoint-loading helper defined alongside build_sam2 in build_sam.py
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
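
For reference, a typical call looks roughly like the sketch below; the config and checkpoint names follow the tiny model released with the SAM2 repo and are assumptions you may need to adjust for your own setup.

sam2_model = build_sam2(
    config_file="sam2_hiera_t.yaml",
    ckpt_path="checkpoints/sam2_hiera_tiny.pt",
    device="cuda",
    mode="eval",
)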

The beginning of build_sam2 assembles the configuration: when apply_postprocessing is set, a few extra overrides are appended to hydra_overrides_extra, and the compose and instantiate calls that follow are both library functions from Hydra. Hydra is an open-source Python framework, also developed by the Meta team, that simplifies the development of research and other complex applications. Its main capability is dynamically creating hierarchical configurations by composition and overriding them through config files and the command line. Hydra's reading and writing of YAML files is built on top of the OmegaConf library.

Back to our code: compose reads the YAML file named by the config_name argument and turns it into a Python object that can be accessed like a nested dict, with the entries listed in overrides replacing the corresponding values parsed from the YAML.
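
To make this concrete, here is a minimal sketch of calling compose with overrides outside of build_sam2. The config_path and config_name are hypothetical placeholders, and it assumes Hydra >= 1.2 with the YAML shown later in this post available in that directory.

from hydra import compose, initialize
from omegaconf import OmegaConf

# Point Hydra at a config directory (relative to this file), then compose a config.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="sam2_hiera_t.yaml",
        overrides=["++model.image_size=512"],  # '++' adds the key, or overrides it if it already exists
    )

# The result is an OmegaConf DictConfig, accessible like nested dicts/attributes.
print(cfg.model.image_size)          # 512
print(OmegaConf.to_yaml(cfg.model))  # dump the resolved "model" section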

instantiate then builds the actual network model from this configuration; with _recursive_=True, nested _target_ entries are instantiated first and passed to their parent as already-constructed objects. This is easier to see with a small example.
An example YAML file:

optimizer:
  _target_: my_app.Optimizer
  algo: SGD
  lr: 0.01

The corresponding class:

class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

Instantiating it from the config:

opt = instantiate(cfg.optimizer)
print(opt)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
opt = instantiate(cfg.optimizer, lr=0.2)
print(opt)
# Optimizer(algo=SGD,lr=0.2)

Next, let's look at the concrete structure of SAM2 (taking the tiny variant as an example):

model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          layer_scale_init_value: 1e-6
          use_dwconv: True  # depth-wise convs
        num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  # SAM decoder
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: false
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  # HieraT does not currently support compilation, should always be set to False
  compile_image_encoder: False

As we discussed in the SAM2 overview, the SAM2 model is composed of an image_encoder, a memory_attention module, and a memory_encoder (see the corresponding top-level keys in the YAML above).
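
As a quick sanity check, the three sub-modules can be inspected on an instantiated model. This is a sketch: it assumes the sam2 package (which registers its Hydra configs on import) is installed, and that SAM2Base stores the sub-modules under attributes named after these constructor arguments.

sam2_model = build_sam2("sam2_hiera_t.yaml", ckpt_path=None, device="cpu")
print(type(sam2_model.image_encoder).__name__)     # ImageEncoder
print(type(sam2_model.memory_attention).__name__)  # MemoryAttention
print(type(sam2_model.memory_encoder).__name__)    # MemoryEncoder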

Image Encoder

From the YAML file we can clearly see that the Image Encoder consists of two parts: a Hiera model as the trunk and an FpnNeck as the neck.
Hiera is a hierarchical vision transformer pre-trained as a masked autoencoder (MAE), proposed in "Hiera: A hierarchical vision transformer without the bells-and-whistles. ICML, 2023." The Hiera encoder extracts multi-scale features, and a feature pyramid network (FPN, the FpnNeck here) fuses the extracted features.
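
To see how instantiate wires these pieces together, the same tiny encoder can also be built by hand, mirroring the YAML values above. This is a sketch assuming the sam2 package is importable; the constructor arguments are copied verbatim from the config.

from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera(
    embed_dim=96,
    num_heads=1,
    stages=[1, 2, 7, 2],
    global_att_blocks=[5, 7, 9],
    window_pos_embed_bkg_spatial_size=[7, 7],
)
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(
        num_pos_feats=256, normalize=True, scale=None, temperature=10000
    ),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],
    fpn_top_down_levels=[2, 3],
    fpn_interp_model="nearest",
)
image_encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)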

Next, let's look at the Image Encoder code:

import torch
from torch import nn


class ImageEncoder(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        # Forward through backbone
        features, pos = self.neck(self.trunk(sample))
        if self.scalp > 0:
            # Discard the lowest resolution features
            features, pos = features[: -self.scalp], pos[: -self.scalp]

        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output

The key line is features, pos = self.neck(self.trunk(sample)): inside ImageEncoder the sample first passes through the trunk and then through the neck, i.e., it is processed by Hiera first and by FpnNeck afterwards.
FPN is a fairly old technique in computer vision, and it does exactly what its name suggests, so we only sketch it here. One detail worth noting is that the position_encoding module does not transform x at all; it only derives pos from the shape of x.
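
Reusing the image_encoder built in the earlier sketch, a dummy image makes the trunk-then-neck data flow visible; the spatial sizes in the comment are what we expect for a 1024x1024 input to the tiny model.

import torch

image = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    out = image_encoder(image)

for feat, pe in zip(out["backbone_fpn"], out["vision_pos_enc"]):
    print(tuple(feat.shape), tuple(pe.shape))
# With scalp=1 the lowest-resolution level is discarded, so we expect three levels
# of spatial size 256x256, 128x128 and 64x64, each projected to 256 channels;
# out["vision_features"] is the last (64x64) map.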

Neck: FpnNeck

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FpnNeck(nn.Module):

    '''
    Values taken from the yaml config above:
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96]
    fpn_top_down_levels=[2, 3]
    fpn_interp_model=nearest
    '''

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,  
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(  ## the lateral (skip-connection) 1x1 convolution for this level
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )

            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):

        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos

The interpolate call performs the upsampling, and the 1x1 conv projects each backbone feature map to the common dimension d_model. The data flow is the standard top-down FPN pattern.
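
A tiny check of the upsampling step (the tensor here is arbitrary): nearest-neighbor interpolation with scale_factor=2.0 doubles the spatial size, which is exactly the stride gap between adjacent pyramid levels.

import torch
import torch.nn.functional as F

feat = torch.randn(1, 256, 32, 32)
up = F.interpolate(feat, scale_factor=2.0, mode="nearest")
print(up.shape)  # torch.Size([1, 256, 64, 64])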

The condition if i in self.fpn_top_down_levels shows that top-down FPN fusion is only applied at the levels listed in fpn_top_down_levels. The output is a tuple (out, pos). Looking at out first: it is a list of tensors, each of shape (B, d_model, H_i, W_i), where H_i and W_i are the spatial dimensions of the corresponding backbone feature map xs[i]. pos is produced by the position_encoding module, whose class is shown next.
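
The top-down behaviour can be observed by running the neck alone on dummy multi-scale features shaped like the tiny Hiera outputs. This sketch reuses the neck built earlier; the channel and resolution values follow the config above.

import torch

# xs is ordered from high resolution (96 channels, stride 4) to low resolution (768 channels, stride 32).
xs = [
    torch.randn(1, 96, 256, 256),   # level 0
    torch.randn(1, 192, 128, 128),  # level 1
    torch.randn(1, 384, 64, 64),    # level 2
    torch.randn(1, 768, 32, 32),    # level 3
]
out, pos = neck(xs)
for o in out:
    print(tuple(o.shape))  # every level is projected to d_model=256 channels

# With fpn_top_down_levels=[2, 3], levels 0 and 1 are pure lateral 1x1 projections;
# level 3 sits at the top of the pyramid, so level 2 is the only one that actually
# adds an upsampled top-down feature.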

from typing import Optional

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):  ## class of the position_encoding instance passed to FpnNeck
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
    ):
      ...
    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        # The spatial size serves as the cache key (the cache-hit early-return branch is omitted here).
        cache_key = (x.shape[-2], x.shape[-1])
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos

As the official docstring notes, it is very similar to the positional encoding from "Attention Is All You Need":
$$p_{k,2i}=\sin\!\left(\frac{k}{10000^{2i/d}}\right),\qquad p_{k,2i+1}=\cos\!\left(\frac{k}{10000^{2i/d}}\right)$$

The two dim_t lines compute $10000^{2i/d}$; the divisions that produce pos_x and pos_y then give $\frac{x}{10000^{2i/d}}$ and $\frac{y}{10000^{2i/d}}$ for the x and y coordinates respectively.
The stack and flatten operations that follow apply sin to the even channels and cos to the odd channels, completing the encodings for pos_x and pos_y.

Note "similar" rather than "identical". The computation implemented by the code is the following (a small standalone sketch follows the formulas):

  • For pos_x:
    $p_{x,y,2i}=\sin\left(\frac{x}{10000^{2i/d}}\right),\quad p_{x,y,2i+1}=\cos\left(\frac{x}{10000^{2i/d}}\right)$

  • For pos_y:
    $p_{x,y,2i}=\sin\left(\frac{y}{10000^{2i/d}}\right),\quad p_{x,y,2i+1}=\cos\left(\frac{y}{10000^{2i/d}}\right)$
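
The following standalone sketch (a hypothetical helper, without the normalization and caching of the real class) reproduces the per-axis computation and may make the formulas above more concrete.

import torch

def axis_sine_embedding(coords: torch.Tensor, num_feats: int, temperature: float = 10000.0) -> torch.Tensor:
    """Encode 1D positions `coords` into `num_feats` sinusoidal channels."""
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)   # 10000^{2i/d}
    pos = coords[..., None] / dim_t                         # k / 10000^{2i/d}
    # Even channels take sin, odd channels take cos, then the pairs are interleaved.
    return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)

x_coords = torch.arange(1, 5, dtype=torch.float32)  # x_embed for a 4-pixel-wide row
print(axis_sine_embedding(x_coords, num_feats=8).shape)  # torch.Size([4, 8])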

Final thoughts

For a code-walkthrough post like this, I wonder whether a video format would actually work better.

From: https://blog.csdn.net/qq_36553572/article/details/141091851
