Cutie (CVPR 2024) Source Code Analysis
1. Data Preprocessing
def forward(self, data: Dict):
    out = {}
    frames = data['rgb']  # input RGB frame sequence
    first_frame_gt = data['first_frame_gt'].float()  # ground-truth masks of the first frame
    b, seq_length = frames.shape[:2]  # batch size and sequence length
    # During dataset preprocessing, the number of selected objects never exceeds
    # the num_objects value set in train_config
    num_filled_objects = [o.item() for o in data['info']['num_objects']]  # actual object count per sample, as a Python list
    max_num_objects = max(num_filled_objects)  # largest object count in the batch
    first_frame_gt = first_frame_gt[:, :, :max_num_objects]  # crop the first-frame GT to the largest object count
    selector = data['selector'][:, :max_num_objects].unsqueeze(2).unsqueeze(2)  # tensor used to select the valid objects
    num_objects = first_frame_gt.shape[2]
    out['num_filled_objects'] = num_filled_objects

    def get_ms_feat_ti(ti):
        return [f[:, ti] for f in ms_feat]  # extract the features of time step ti from ms_feat

    with torch.cuda.amp.autocast(enabled=self.use_amp):
        frames_flat = frames.view(b * seq_length, *frames.shape[2:])
As the code above shows, the initial input data is a dictionary. It contains five entries; only the three most important are described here.
- Input data (data: Dict)
  - rgb: the raw input video frames, with shape (b, t, c, h, w)
  - first_frame_gt: the annotated masks of the first frame, with shape (b, 1, num_objects, h, w); the slicing first_frame_gt[:, :, :max_num_objects] above shows that the object dimension is dim 2
    - num_objects here is the parameter set in train_config (located at cutie/cutie/config/train_config.yaml)
    - num_objects is the maximum number of objects the model is configured to accept plus 1, where the extra 1 accounts for the background
    - When the first-frame GT is generated, its shape is allocated according to num_objects. If the video contains fewer than num_objects - 1 objects, the trailing object slots are left empty; if it contains more, num_objects - 1 objects are randomly sampled for training.
  - selector: shape (b, num_objects)
The code processes the input data as follows (a minimal shape walkthrough is sketched below):
- Input data processing
  - Flatten the rgb frames and store the result as frames_flat
  - Crop first_frame_gt to the largest number of objects that actually appear in the batch
  - Crop selector in the same way, then expand its dimensions for broadcasting
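Here is a minimal, self-contained sketch of these shape manipulations using dummy tensors; the concrete sizes below are illustrative assumptions, not values taken from the repo:

import torch

b, t, c, h, w = 2, 8, 3, 384, 384   # illustrative sizes
max_num_objects = 3                  # pretend at most 3 objects appear in this batch

frames = torch.randn(b, t, c, h, w)          # stands in for data['rgb']
first_frame_gt = torch.randn(b, 1, 5, h, w)  # 5 object slots allocated by the dataset
selector = torch.ones(b, 5)                  # stands in for data['selector']

# flatten time into the batch dimension
frames_flat = frames.view(b * t, c, h, w)
print(frames_flat.shape)  # torch.Size([16, 3, 384, 384])

# crop GT and selector to the largest object count, expand selector for broadcasting
first_frame_gt = first_frame_gt[:, :, :max_num_objects]             # (b, 1, 3, h, w)
selector = selector[:, :max_num_objects].unsqueeze(2).unsqueeze(2)  # (b, 3, 1, 1)
print(first_frame_gt.shape, selector.shape)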
2. Video Frame Feature Extraction
2.1 Pixel Encoder Feature Extraction
- pixel encoder
  - The first three stages of ResNet-50, each producing one feature map
  - Outputs ms_feat (the three ResNet-50 feature maps, deepest first) and pix_feat (ms_feat[0] with its channel dimension projected down)
with torch.cuda.amp.autocast(enabled=self.use_amp):  # enable mixed precision or not, depending on the config
    frames_flat = frames.view(b * seq_length, *frames.shape[2:])  # flatten frames from (b, t, c, h, w) to (b*t, c, h, w)
    ms_feat, pix_feat = self.encode_image(frames_flat)
def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
    image = (image - self.pixel_mean) / self.pixel_std  # normalize the input frames
    ms_image_feat = self.pixel_encoder(image)  # multi-scale feature maps, deepest first
    # project the deepest map down to pixel_dim channels with a 1x1 conv
    return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])

where pix_feat_proj is defined as:

self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
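To make the multi-scale shapes concrete, here is a sketch of what a three-stage ResNet-50 trunk produces; the torchvision construction below is my own illustrative stand-in for Cutie's pixel encoder, not code from the repo:

import torch
import torchvision

# first three stages of ResNet-50 (stand-in for the pixel encoder)
resnet = torchvision.models.resnet50(weights=None)
stem = torch.nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)

x = torch.randn(16, 3, 384, 384)  # frames_flat: (b*t, c, h, w)
f4 = resnet.layer1(stem(x))       # 1/4 resolution,  256 channels
f8 = resnet.layer2(f4)            # 1/8 resolution,  512 channels
f16 = resnet.layer3(f8)           # 1/16 resolution, 1024 channels

ms_feat = [f16, f8, f4]           # deepest first, matching ms_dims[0] == 1024
print([f.shape for f in ms_feat])
# [torch.Size([16, 1024, 24, 24]), torch.Size([16, 512, 48, 48]), torch.Size([16, 256, 96, 96])]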
2.2 transform_key (Key Projection)
Key features are extracted from the deepest feature map of ms_feat (ms_feat[0], the third-stage output):
- First project it to 256 channels
- key: project the 256-channel map down to 64 channels
- shrinkage: project the 256-channel map down to 1 channel
- selection: project the 256-channel map to 64 channels, then pass the result through a sigmoid
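For context beyond this excerpt: in XMem, whose memory-reading design Cutie builds on, shrinkage and selection modulate the anisotropic L2 similarity between a memory key and a query, which is presumably why shrinkage is forced to be at least 1 and selection is squashed into (0, 1):

$$S(\mathbf{k}_i, \mathbf{q}_j) = -\,s_i \sum_{c} e_{jc}\,(k_{ic} - q_{jc})^2$$

where $s_i$ is the shrinkage term at memory position $i$ and $e_{jc}$ is the selection weight of channel $c$ at query position $j$. The call site and the KeyProjection module are shown below.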
with torch.cuda.amp.autocast(enabled=False):  # disable mixed precision so everything in transform_key runs in full FP32 precision
    keys, shrinkages, selections = self.transform_key(ms_feat[0].float())
class KeyProjection(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        in_dim = model_cfg.pixel_encoder.ms_dims[0]  # 1024
        mid_dim = model_cfg.pixel_dim  # 256
        key_dim = model_cfg.key_dim  # 64

        self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
        self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
        # shrinkage: one scalar per spatial location
        self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
        # selection: one weight per key channel
        self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x: torch.Tensor, *, need_s: bool,
                need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x = self.pix_feat_proj(x)
        shrinkage = self.d_proj(x)**2 + 1 if need_s else None  # squaring plus 1 keeps shrinkage >= 1
        selection = torch.sigmoid(self.e_proj(x)) if need_e else None  # sigmoid keeps selection in (0, 1)
        return self.key_proj(x), shrinkage, selection
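A quick smoke test of the KeyProjection above with a hand-built config; the OmegaConf scaffolding and concrete input size are my own assumptions, chosen to mirror the dimension comments in the code:

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'pixel_encoder': {'ms_dims': [1024, 512, 256]},
    'pixel_dim': 256,
    'key_dim': 64,
})
proj = KeyProjection(cfg)

x = torch.randn(16, 1024, 24, 24)  # ms_feat[0] for a 384x384 input, as above
key, shrinkage, selection = proj(x, need_s=True, need_e=True)
print(key.shape, shrinkage.shape, selection.shape)
# torch.Size([16, 64, 24, 24]) torch.Size([16, 1, 24, 24]) torch.Size([16, 64, 24, 24])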
2.3 Feature Map Dimension Reshaping
All of the feature maps obtained above are reshaped so that the time dimension, which was previously flattened into the batch dimension, is recovered.
h, w = keys.shape[-2:]
keys = self.move_t_from_batch_to_volume(keys)
shrinkages = self.move_t_from_batch_to_volume(shrinkages)
selections = self.move_t_from_batch_to_volume(selections)
ms_feat = [self.move_t_out_of_batch(f) for f in ms_feat]
pix_feat = self.move_t_out_of_batch(pix_feat)

where move_t_out_of_batch is an einops layer:

self.move_t_out_of_batch = Rearrange('(b t) c h w -> b t c h w', t=self.seq_length)
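A standalone sketch of what these einops layers do. move_t_from_batch_to_volume is not shown in the excerpt, so its '(b t) c h w -> b c t h w' pattern below is an assumption based on memory keys typically being stored as (b, c, t, h, w) volumes:

import torch
from einops.layers.torch import Rearrange

b, t = 2, 8
keys = torch.randn(b * t, 64, 24, 24)  # flattened keys: (b*t, c, h, w)

move_t_out_of_batch = Rearrange('(b t) c h w -> b t c h w', t=t)
# assumed pattern: keys are stacked into a (b, c, t, h, w) memory volume
move_t_from_batch_to_volume = Rearrange('(b t) c h w -> b c t h w', t=t)

print(move_t_out_of_batch(keys).shape)          # torch.Size([2, 8, 64, 24, 24])
print(move_t_from_batch_to_volume(keys).shape)  # torch.Size([2, 64, 8, 24, 24])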