
BEVFormer Open-Source Algorithm Line-by-Line Walkthrough (Part 2): Decoder and Det

Posted: 2024-09-01 10:28:03
Tags: reference self Det BEVFormer query embed bev line-by-line proj

Preface:

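Plenty of material is available on the overall framework of BEVFormer, but detailed explanations of its code are scarce. The goal of this series is therefore to walk through the implementation details and follow how tensor dimensions change, so that readers gain a more concrete understanding of the algorithm.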

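In this series we go through the public BEVFormer code (the open-source implementation) line by line, connecting the code to BEVFormer's principles and covering the algorithmic details, to help readers develop perception algorithms on top of this framework. At the end of the series, for users of Horizon Robotics, I will also point out the changes that Horizon's reference algorithm makes relative to the open-source version and the reasoning behind them, as a reference for deployment.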

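The public code is well encapsulated and builds models through registries, so the call relationships between modules can be read directly from the config files in configs/bevformer. We use bevformer_tiny.py as the example. The Encoder part has already been published in "BEVFormer Open-Source Algorithm Line-by-Line Walkthrough (Part 1): Encoder"; this article focuses on BEVFormer's Decoder and Det (detection head) parts.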

The explanations are given mainly as comments in the code.

1 PerceptionTransformer:

Purpose:

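  • Passes the bev_embed produced by the encoder into the decoder;
  • Splits the query_embedding defined in BEVFormer along the channel dimension into two halves of equal width, query_pos and query, and passes both to the decoder;
  • Feeds query_pos through the linear layer self.reference_points (followed by a sigmoid) to produce reference_points, which are passed to the decoder. Inside the decoder, CustomMSDeformableAttention uses these reference_points as the base sampling locations for fusing bev_embed, a role similar to that of region proposals in a two-stage object detector;
  • Returns inter_states and inter_references, which cls_branches and reg_branches use to predict object classes and bboxes.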
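To make the split and the reference-point head concrete, here is a minimal pure-Python sketch for a single query row. It assumes a shrunken embed_dims of 4 (256 in bevformer_tiny) and random stand-in weights for self.reference_points; the helper names split_query_embedding, linear and sigmoid are hypothetical and not BEVFormer APIs.

```python
import math
import random


def split_query_embedding(row, embed_dims):
    """Split one row of object_query_embed (length 2 * embed_dims) into
    (query_pos, query), like torch.split(object_query_embed, embed_dims, dim=1)."""
    assert len(row) == 2 * embed_dims
    return row[:embed_dims], row[embed_dims:]


def linear(x, weight, bias):
    """y = W x + b, with weight shaped [out_features][in_features] (nn.Linear convention)."""
    return [sum(w * xi for w, xi in zip(w_row, x)) + b for w_row, b in zip(weight, bias)]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-t)) for t in v]


embed_dims = 4  # 256 in bevformer_tiny; shrunk for readability
rng = random.Random(0)
row = [rng.uniform(-1.0, 1.0) for _ in range(2 * embed_dims)]  # one of the 900 query rows
query_pos, query = split_query_embedding(row, embed_dims)

# Hypothetical stand-in for self.reference_points = nn.Linear(embed_dims, 3)
weight = [[rng.uniform(-1.0, 1.0) for _ in range(embed_dims)] for _ in range(3)]
bias = [0.0, 0.0, 0.0]
reference_point = sigmoid(linear(query_pos, weight, bias))  # normalized (x, y, z)
print(reference_point)
```

Because of the sigmoid, every coordinate lands in (0, 1): the initial reference points are normalized coordinates, not metric positions. The real code does the same for all 900 rows at once on a [900, 512] tensor.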

Walkthrough:

```python
# See Part 1 (Encoder) for details; this produces bev_embed.
# In the decoder, CustomMSDeformableAttention fuses bev_embed with the queries.
bev_embed = self.get_bev_features(
    mlvl_feats,
    bev_queries,
    bev_h,
    bev_w,
    grid_length=grid_length,
    bev_pos=bev_pos,
    prev_bev=prev_bev,
    **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

bs = mlvl_feats[0].size(0)
# object_query_embed: torch.Size([900, 512])
# query_pos: torch.Size([900, 256])
# query: torch.Size([900, 256])
query_pos, query = torch.split(
    object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
# reference_points: torch.Size([1, 900, 3])
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points

# query: torch.Size([900, 1, 256])
query = query.permute(1, 0, 2)
# query_pos: torch.Size([900, 1, 256])
query_pos = query_pos.permute(1, 0, 2)
# bev_embed: torch.Size([50*50, 1, 256])
bev_embed = bev_embed.permute(1, 0, 2)

# Enter the decoder
inter_states, inter_references = self.decoder(
    query=query,
    key=None,
    value=bev_embed,
    query_pos=query_pos,
    reference_points=reference_points,
    reg_branches=reg_branches,
    cls_branches=cls_branches,
    spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
    level_start_index=torch.tensor([0], device=query.device),
    **kwargs)

# Return inter_states and inter_references; they are later fed to
# cls_branches and reg_branches to obtain object classes and bboxes.
inter_references_out = inter_references

return bev_embed, inter_states, init_reference_out, inter_references_out
```

2 DetectionTransformerDecoder

Purpose:

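  • Iterates through 6 identical DetrTransformerDecoderLayer modules. Each DetrTransformerDecoderLayer has the operation order ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'), and each layer produces an output and refined reference_points;
  • After all 6 DetrTransformerDecoderLayer modules have run, returns the output and reference_points of every layer (with return_intermediate=True, as set in the config).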
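The loop can be mimicked in plain Python to see how return_intermediate collects one result per layer and how each layer nudges the reference point. This is a sketch under stated assumptions: `layers` and `reg_branches` are hypothetical toy callables, not DetrTransformerDecoderLayer or the real regression heads, and `logit` is an unclamped inverse sigmoid (the real code uses mmdet's clamped inverse_sigmoid).

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def logit(p):
    # Plain inverse of sigmoid, valid for p strictly inside (0, 1)
    return math.log(p / (1.0 - p))


def refine(ref, tmp):
    """One refinement step as in DetectionTransformerDecoder:
    x, y are updated with tmp[0:2]; z is updated with tmp[4:5]."""
    return [sigmoid(tmp[0] + logit(ref[0])),
            sigmoid(tmp[1] + logit(ref[1])),
            sigmoid(tmp[4] + logit(ref[2]))]


def run_decoder(query, ref, layers, reg_branches):
    """Toy decoder loop with return_intermediate=True."""
    intermediate, intermediate_refs = [], []
    output = query
    for lid, layer in enumerate(layers):
        output = layer(output, ref[:2])   # only (x, y) is fed to cross-attention
        tmp = reg_branches[lid](output)   # 10-dim box regression for this layer
        ref = refine(ref, tmp)            # the real code also calls .detach() here
        intermediate.append(output)
        intermediate_refs.append(ref)
    return intermediate, intermediate_refs


num_layers = 6
layers = [lambda out, ref_xy: [v + 1.0 for v in out]] * num_layers
reg_branches = [lambda out: [0.1, -0.1, 0, 0, 0.2, 0, 0, 0, 0, 0]] * num_layers
outs, refs = run_decoder([0.0, 0.0], [0.5, 0.5, 0.5], layers, reg_branches)
```

In the real decoder, torch.stack turns these per-layer lists into tensors of shape [6, 900, 1, 256] (outputs) and [6, 1, 900, 3] (reference points).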

Walkthrough:

```python
# output: torch.Size([900, 1, 256])
output = query
intermediate = []
intermediate_reference_points = []
# Loop over the 6 identical DetrTransformerDecoderLayer modules
for lid, layer in enumerate(self.layers):
    # reference_points_input: torch.Size([1, 900, 1, 2])
    # In the decoder, CustomMSDeformableAttention uses these reference_points
    # as the base sampling locations for fusing bev_embed
    reference_points_input = reference_points[..., :2].unsqueeze(
        2)  # BS NUM_QUERY NUM_LEVEL 2
    # Run one DetrTransformerDecoderLayer
    output = layer(
        output,
        *args,
        reference_points=reference_points_input,
        key_padding_mask=key_padding_mask,
        **kwargs)
    # output: torch.Size([1, 900, 256])
    output = output.permute(1, 0, 2)

    if reg_branches is not None:
        # tmp: torch.Size([1, 900, 10])
        tmp = reg_branches[lid](output)

        assert reference_points.shape[-1] == 3

        # new_reference_points: torch.Size([1, 900, 3])
        new_reference_points = torch.zeros_like(reference_points)
        new_reference_points[..., :2] = tmp[
            ..., :2] + inverse_sigmoid(reference_points[..., :2])
        new_reference_points[..., 2:3] = tmp[
            ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

        new_reference_points = new_reference_points.sigmoid()

        reference_points = new_reference_points.detach()

    output = output.permute(1, 0, 2)
    if self.return_intermediate:
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

# After all 6 DetrTransformerDecoderLayer modules have run,
# return the output and reference_points of every layer.
if self.return_intermediate:
    return torch.stack(intermediate), torch.stack(
        intermediate_reference_points)

return output, reference_points
```

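The figure below shows the structure of the reference_points generated by the highlighted code (the new_reference_points lines). For the first layer, inverse_sigmoid(pt_reference_points) is simply the pre-sigmoid reference point, i.e. the output of the reference_points linear layer, Linear(query_pos).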
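The relation between inverse_sigmoid and the reference_points layer can be checked numerically. The sketch below is a scalar re-implementation that assumes mmdet's inverse_sigmoid definition (clamp to [0, 1], then guard numerator and denominator with eps=1e-5); the values in linear_out are made up for illustration.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def inverse_sigmoid(x, eps=1e-5):
    """Scalar mirror of mmdet's inverse_sigmoid: clamp x to [0, 1], keep both
    x and 1 - x at least eps, then return log(x / (1 - x))."""
    x = min(max(x, 0.0), 1.0)
    return math.log(max(x, eps) / max(1.0 - x, eps))


linear_out = [0.7, -1.3, 2.0]                   # hypothetical Linear(query_pos) output
init_ref = [sigmoid(v) for v in linear_out]     # initial reference_points
recovered = [inverse_sigmoid(p) for p in init_ref]
```

The eps guard is what keeps inverse_sigmoid finite when a reference coordinate saturates at exactly 0 or 1.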

2.1 MultiheadAttention

Purpose:

  • Multi-head self-attention over the object queries, as shown in the figure below.
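Condensed to its essentials, this self-attention can be sketched in pure Python. Assumptions: all projection matrices (in_proj_weight, out_proj) are replaced by the identity and dropout and the residual are omitted, so only the head split, scaled dot product and softmax remain. As in the traced code, query_pos is added to the query and key inputs but not to the value.

```python
import math


def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def split_heads(x, num_heads):
    """[N][C] -> [num_heads][N][C // num_heads], like view(...).transpose(0, 1)."""
    head_dim = len(x[0]) // num_heads
    return [[row[h * head_dim:(h + 1) * head_dim] for row in x]
            for h in range(num_heads)]


def multi_head_self_attention(query, query_pos, num_heads):
    """Self-attention over object queries with identity projections (assumption)."""
    q_in = [[a + b for a, b in zip(r, p)] for r, p in zip(query, query_pos)]  # query + query_pos
    k_in = q_in    # key + key_pos: the same tensor in decoder self-attention
    v_in = query   # the value carries no positional embedding
    heads_out, heads_attn = [], []
    for q, k, v in zip(*(split_heads(t, num_heads) for t in (q_in, k_in, v_in))):
        head_dim = len(q[0])
        # scaled dot product: (q . k) / sqrt(head_dim) -> [N][N]
        scores = [[sum(qi * ki for qi, ki in zip(qr, kr)) / math.sqrt(head_dim)
                   for kr in k] for qr in q]
        attn = [softmax(r) for r in scores]
        # attn @ v -> [N][head_dim]
        out = [[sum(a * v[j][c] for j, a in enumerate(row)) for c in range(head_dim)]
               for row in attn]
        heads_out.append(out)
        heads_attn.append(attn)
    # merge heads: [num_heads][N][head_dim] -> [N][num_heads * head_dim]
    merged = [sum((h_out[i] for h_out in heads_out), []) for i in range(len(query))]
    return merged, heads_attn


query = [[1.0, 0.0, 0.5, -0.5],
         [0.0, 1.0, -0.5, 0.5],
         [1.0, 1.0, 0.0, 0.0]]
query_pos = [[0.1, 0.0, 0.0, 0.1],
             [0.0, 0.1, 0.1, 0.0],
             [0.0, 0.0, 0.1, 0.1]]
out, attn = multi_head_self_attention(query, query_pos, num_heads=2)
```

With bevformer_tiny's sizes (900 queries, 256 channels, 8 heads) each head works on 32 channels and produces a 900 x 900 attention map, which matches the [8, 900, 900] shape in the trace.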

Walkthrough:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

embed_dim = 256
kdim = embed_dim
vdim = embed_dim
qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim  # True
num_heads = 8
dropout = 0.1
batch_first = False
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
factory_kwargs = {'device': 'cuda', 'dtype': None}
in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
bias_k = bias_v = None
add_zero_attn = False
out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=True, **factory_kwargs)
attn_mask = attn_mask  # None

if batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not qkv_same_embed_dim:
    # attn_output, attn_output_weights = F.multi_head_attention_forward(
    #     query, key, value, self.embed_dim, self.num_heads,
    #     self.in_proj_weight, self.in_proj_bias,
    #     self.bias_k, self.bias_v, self.add_zero_attn,
    #     self.dropout, self.out_proj.weight, self.out_proj.bias,
    #     training=self.training,
    #     key_padding_mask=key_padding_mask, need_weights=need_weights,
    #     attn_mask=attn_mask, use_separate_proj_weight=True,
    #     q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
    #     v_proj_weight=self.v_proj_weight)
    pass
else:
    # attn_output, attn_output_weights = F.multi_head_attention_forward(
    #     query, key, value, embed_dim, num_heads, in_proj_weight, in_proj_bias,
    #     bias_k, bias_v, add_zero_attn, dropout, out_proj.weight, out_proj.bias,
    #     training=True, key_padding_mask=None, need_weights=True, attn_mask=attn_mask)
    # The call above is expanded inline below.
    # -------------------------------F.multi_head_attention_forward start----------------------------
    out_proj_weight = out_proj.weight
    out_proj_bias = out_proj.bias
    key = key
    value = value
    embed_dim_to_check = embed_dim
    use_separate_proj_weight = False
    training = True
    key_padding_mask = None
    need_weights = True
    q_proj_weight, k_proj_weight, v_proj_weight = None, None, None
    static_k, static_v = None, None

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape  # torch.Size([900, 1, 256])
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # # embed_dim can be a tensor when JIT tracing
        # head_dim = embed_dim.div(mhaf_num_heads, rounding_mode='trunc')
        pass
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    if not use_separate_proj_weight:
        # q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
        # -----------_in_projection_packed start-----------
        # q, k, v, w, b = query, mhaf_key, mhaf_value, mhaf_in_proj_weight, mhaf_in_proj_bias
        # E = query.size(-1)
        if key is value:
            # if query is mhaf_key:
            #     # self-attention
            #     return linear(query, mhaf_in_proj_weight, mhaf_in_proj_bias).chunk(3, dim=-1)
            # else:
            #     # encoder-decoder attention
            #     w_q, w_kv = mhaf_in_proj_weight.split([E, E * 2])
            #     if mhaf_in_proj_bias is None:
            #         b_q = b_kv = None
            #     else:
            #         b_q, b_kv = mhaf_in_proj_bias.split([E, E * 2])
            #     return (linear(query, w_q, b_q),) + linear(mhaf_key, w_kv, b_kv).chunk(2, dim=-1)
            pass
        else:
            w_q, w_k, w_v = in_proj_weight.chunk(3)
            if in_proj_bias is None:
                # b_q = b_k = b_v = None
                pass
            else:
                b_q, b_k, b_v = in_proj_bias.chunk(3)
            # return linear(query, w_q, b_q), linear(mhaf_key, w_k, b_k), linear(mhaf_value, w_v, b_v)
            # F.linear(x, A, b): return x @ A.T + b
            # inputs: query = query + pt_query_pos, key = query + pt_query_pos, value = query
            query, key, value = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
        # ------------_in_projection_packed end------------
    # else:
    #     assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
    #     assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
    #     assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
    #     if in_proj_bias is None:
    #         b_q = b_k = b_v = None
    #     else:
    #         b_q, b_k, b_v = in_proj_bias.chunk(3)
    #     q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # reshape q, k, v for multihead attention and make em batch first
    query = query.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 1, 256] -> [900, 8, 32] -> [8, 900, 32]
    if static_k is None:
        key = key.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
    # else:
    #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
    #     assert mhaf_static_k.size(0) == bsz * mhaf_num_heads, \
    #         f"expecting static_k.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_k.size(0)}"
    #     assert mhaf_static_k.size(2) == head_dim, \
    #         f"expecting static_k.size(2) of {head_dim}, but got {mhaf_static_k.size(2)}"
    #     mhaf_key = mhaf_static_k
    if static_v is None:
        value = value.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
    # else:
    #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
    #     assert mhaf_static_v.size(0) == bsz * mhaf_num_heads, \
    #         f"expecting static_v.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_v.size(0)}"
    #     assert mhaf_static_v.size(2) == head_dim, \
    #         f"expecting static_v.size(2) of {head_dim}, but got {mhaf_static_v.size(2)}"
    #     mhaf_value = mhaf_static_v

    # update source sequence length after adjustments
    src_len = key.size(1)

    # attn_output, attn_output_weights = _scaled_dot_product_attention(query, key, value, attn_mask, dropout)
    # ------------_scaled_dot_product_attention start------------
    # q: Tensor,
    # k: Tensor,
    # v: Tensor,
    # attn_mask: Optional[Tensor] = None,
    # dropout_p: float = 0.0,
    B, Nt, E = query.shape  # torch.Size([8, 900, 32]); key and value have the same shape
    query = query / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(query, key.transpose(-2, -1))  # [8, 900, 32] @ [8, 32, 900] -> [8, 900, 900]
    # if mhaf_attn_mask is not None:
    #     attn += mhaf_attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout > 0.0:
        attn = F.dropout(attn, p=dropout)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, value)  # [8, 900, 900] @ [8, 900, 32] -> torch.Size([8, 900, 32])
    # return output, attn
    attn_output, attn_output_weights = output, attn
    # -------------_scaled_dot_product_attention end-------------

    # tgt_len: 900; [8, 900, 32] -> [900, 8, 32] -> [900, 1, 256]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)  # torch.Size([900, 1, 256])
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)  # nn.Linear

out = attn_output
# ------------------------------self.attn end------------------------------
# return mha_identity + self.dropout_layer(self.proj_drop(out))
query = identity + dropout_layer(mha_proj_drop(out))  # torch.Size([900, 1, 256]) + torch.Size([900, 1, 256])
```

2.2 CustomMSDeformableAttention

Purpose:

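  • Uses deformable attention to fuse the bev_embed produced by the encoder into the object queries, as shown in the figure below;
  • Outputs this layer's result, which serves as the input to the next DetrTransformerDecoderLayer and from which (via reg_branches in DetectionTransformerDecoder) the refined reference_points for this layer are computed.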
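The sampling step of deformable attention can be illustrated for a single query, head, level and channel. This sketch assumes the grid_sample convention used by mmcv's multi_scale_deformable_attn_pytorch (align_corners=False, zero padding) and, as in the code, divides x offsets by the map width and y offsets by the map height; bilinear_sample and deformable_attn_one_query are hypothetical helper names.

```python
import math


def bilinear_sample(feat, x, y):
    """Sample a single-channel H x W map at normalized (x, y) in [0, 1].
    align_corners=False convention: pixel centers sit at (i + 0.5) / size;
    neighbors outside the map contribute 0 (zero padding)."""
    h, w = len(feat), len(feat[0])
    px, py = x * w - 0.5, y * h - 0.5
    x0, y0 = math.floor(px), math.floor(py)
    val = 0.0
    for yi in (y0, y0 + 1):
        for xi in (x0, x0 + 1):
            if 0 <= xi < w and 0 <= yi < h:
                val += (1 - abs(px - xi)) * (1 - abs(py - yi)) * feat[yi][xi]
    return val


def softmax(logits):
    m = max(logits)
    e = [math.exp(v - m) for v in logits]
    s = sum(e)
    return [v / s for v in e]


def deformable_attn_one_query(feat, ref_xy, offsets, weight_logits):
    """ref_xy: normalized reference point (x, y).
    offsets: per-point (dx, dy) in pixels of this level (sampling_offsets output).
    weight_logits: per-point logits, softmax-normalized like attention_weights."""
    h, w = len(feat), len(feat[0])
    weights = softmax(weight_logits)
    out = 0.0
    for (dx, dy), a in zip(offsets, weights):
        # offset_normalizer = (W, H): x offsets are divided by width, y by height
        sx, sy = ref_xy[0] + dx / w, ref_xy[1] + dy / h
        out += a * bilinear_sample(feat, sx, sy)
    return out


bev = [[float(4 * r + c) for c in range(4)] for r in range(4)]  # toy 4x4 "bev_embed" channel
ref = ((2 + 0.5) / 4, (1 + 0.5) / 4)                          # center of cell (row 1, col 2)
```

In the real module this runs for 8 heads, 4 points and 32 channels per head, and the offsets and weights come from the sampling_offsets and attention_weights linear layers applied to query + query_pos.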

Walkthrough:

```python
# ----------------- CustomMSDeformableAttention __init__ (partial) -----------------
sampling_offsets = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points * 2).cuda()
attention_weights = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points).cuda()
value_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
output_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
# ----------------- CustomMSDeformableAttention __init__ (partial) -----------------

if value is None:
    value = query
if identity is None:
    identity = query
if query_pos is not None:
    query = query + query_pos

if not self.batch_first:
    # change to (bs, num_query, embed_dims)
    # query: torch.Size([1, 900, 256])
    query = query.permute(1, 0, 2)
    # value (i.e. bev_embed): torch.Size([1, 50*50, 256])
    value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

# value (i.e. bev_embed): torch.Size([1, 50*50, 256])
value = self.value_proj(value)

if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)

# value: torch.Size([1, 50*50, 8, 32]), split into heads
value = value.view(bs, num_value, self.num_heads, -1)

sampling_offsets = self.sampling_offsets(query).view(
    bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
#   1,  900,       8,              1,               4,               2
attention_weights = self.attention_weights(query).view(
    bs, num_query, self.num_heads, self.num_levels * self.num_points)
#   1,  900,       8,              4
attention_weights = attention_weights.softmax(-1)

# attention_weights: torch.Size([1, 900, 8, 1, 4])
attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_levels,
                                           self.num_points)

# reference_points: torch.Size([1, 900, 1, 2])
if reference_points.shape[-1] == 2:
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    # sampling_locations: torch.Size([1, 900, 8, 1, 4, 2])
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')

if torch.cuda.is_available() and value.is_cuda:
    # using fp16 deformable attention is unstable because it performs many sum operations
    if value.dtype == torch.float16:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    else:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:
    # output: torch.Size([1, 900, 256])
    # Deformable attention: sample value (bev_embed) at the query-dependent
    # sampling_locations and combine the samples with attention_weights
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)

# output: torch.Size([1, 900, 256])
output = self.output_proj(output)

if not self.batch_first:
    # (num_query, bs, embed_dims)
    output = output.permute(1, 0, 2)

return self.dropout(output) + identity
```

3 cls_branches & reg_branches

Purpose:

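  • Uses the PerceptionTransformer outputs, namely bev_embed, inter_states (the outputs of the 6 decoder layers), init_reference_out (the initial reference_points generated from query_pos) and inter_references_out (the reference_points of the 6 layers), to predict object classes and bboxes;
  • Builds outs, which contains bev_embed, all_cls_scores and all_bbox_preds. all_cls_scores and all_bbox_preds are used to compute the loss and backpropagate gradients; bev_embed can be reused for tasks such as semantic segmentation in the BEV view.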

Walkthrough:

```python
# See Part 1 (Encoder) for the meaning of the following variables
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
object_query_embeds = self.query_embedding.weight.to(dtype)
bev_queries = self.bev_embedding.weight.to(dtype)

bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                       device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)

if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
    return self.transformer.get_bev_features(
        mlvl_feats,
        bev_queries,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                     self.real_w / self.bev_w),
        bev_pos=bev_pos,
        img_metas=img_metas,
        prev_bev=prev_bev,
    )
else:
    # outputs is the final result of passing object_query_embeds, bev_pos,
    # bev_queries, img_metas and mlvl_feats through the encoder and decoder
    # outputs: bev_embed, inter_states, init_reference_out, inter_references_out
    outputs = self.transformer(
        mlvl_feats,
        bev_queries,
        object_query_embeds,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                     self.real_w / self.bev_w),
        bev_pos=bev_pos,
        reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
        cls_branches=self.cls_branches if self.as_two_stage else None,
        img_metas=img_metas,
        prev_bev=prev_bev)

bev_embed, hs, init_reference, inter_references = outputs
# hs: [6, 900, 1, 256] -> [6, 1, 900, 256]
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)
    # outputs_class: torch.Size([1, 900, 10])
    outputs_class = self.cls_branches[lvl](hs[lvl])
    # tmp: torch.Size([1, 900, 10])
    tmp = self.reg_branches[lvl](hs[lvl])

    # TODO: check the shape of reference
    assert reference.shape[-1] == 3
    tmp[..., 0:2] += reference[..., 0:2]
    tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
    tmp[..., 4:5] += reference[..., 2:3]
    tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
    # The "* (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]" terms below
    # map the normalized x, y, z box centers back to real-world (metric) coordinates
    tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] -
                     self.pc_range[0]) + self.pc_range[0])
    tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] -
                     self.pc_range[1]) + self.pc_range[1])
    tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] -
                     self.pc_range[2]) + self.pc_range[2])

    # TODO: check if using sigmoid
    outputs_coord = tmp
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)

# outputs_classes: torch.Size([6, 1, 900, 10])
outputs_classes = torch.stack(outputs_classes)
# outputs_coords: torch.Size([6, 1, 900, 10])
outputs_coords = torch.stack(outputs_coords)

outs = {
    'bev_embed': bev_embed,
    'all_cls_scores': outputs_classes,
    'all_bbox_preds': outputs_coords,
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
}
# Once returned, outs is used together with the class labels and bbox labels
# to compute the loss; gradients are then backpropagated to update the
# learnable parameters: the linear layers, object_query_embeds, bev_queries, bev_pos, etc.
return outs
```

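The figure below shows the structure of tmp[..., 0:2] and tmp[..., 4:5] produced by the highlighted code. Before the pc_range scaling, they are essentially the same values as the reference_points generated in DetectionTransformerDecoder.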
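The following sketch decodes one query's box center in the same way. It assumes the point_cloud_range [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] from the BEVFormer nuScenes configs and mmdet's clamped inverse_sigmoid; decode_center is a hypothetical helper. Its normalized result is the same refine-in-logit-space computation the decoder performs, and the last three assignments apply the pc_range scaling.

```python
import math

PC_RANGE = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]  # [x_min, y_min, z_min, x_max, y_max, z_max]


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def inverse_sigmoid(x, eps=1e-5):
    x = min(max(x, 0.0), 1.0)
    return math.log(max(x, eps) / max(1.0 - x, eps))


def decode_center(tmp, reference, pc_range=PC_RANGE):
    """tmp: 10-dim reg_branches output for one query; reference: normalized (x, y, z).
    Returns (normalized xyz, metric xyz)."""
    nx = sigmoid(tmp[0] + inverse_sigmoid(reference[0]))
    ny = sigmoid(tmp[1] + inverse_sigmoid(reference[1]))
    nz = sigmoid(tmp[4] + inverse_sigmoid(reference[2]))  # z lives in slot 4
    # map [0, 1] back to the point-cloud range: v * (max - min) + min
    x = nx * (pc_range[3] - pc_range[0]) + pc_range[0]
    y = ny * (pc_range[4] - pc_range[1]) + pc_range[1]
    z = nz * (pc_range[5] - pc_range[2]) + pc_range[2]
    return (nx, ny, nz), (x, y, z)


norm, metric = decode_center([0.0] * 10, (0.5, 0.5, 0.5))
```

With a zero regression delta the center stays at the reference point, so a reference of (0.5, 0.5, 0.5) maps to the middle of the range: (0, 0, -1) in meters.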

Conclusion:

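This completes the line-by-line walkthrough of the Encoder and Decoder in BEVFormer. If there is demand, a follow-up article on the loss computation may be published; that part is fairly basic, so interested readers can also study it on their own alongside the source code.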

From: https://www.cnblogs.com/horizondeveloper/p/18391051
