首页 > 其他分享 >ALBEF-ITC损失部分

ALBEF-ITC损失部分

时间:2023-11-21 21:24:51浏览次数:25  
标签:dim idx text image ITC 损失 ALBEF feat self

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

引言

VLP目标是从大规模图片-文本对子中学习到多模态表示,一次改进下游的视觉-语言任务。

VLP框架的局限性如下:

  1. 图片特征和文字token嵌入在它们各自的空间内,使得多模态encoder难以去学习它们之间的关系。
  2. 目标decoder既需要大量标注也需要大量的计算资源,因为其在预训练时候需要边界框的标注和高分辨率的图片(600x1000)。
  3. 许多图片-文本数据集来自于网络,通常包含噪声,导致如MLM等模型可能会拟合噪声文本,降低模型的泛化能力。

 

作者提出了一种新的VLP框架ALBEF,以此解决以上问题:

1. 跨模态注意力:

作者首先使用detector-free的图片encoder(不需要检查特征点,直接匹配)和文本encoder对图片和文本编码。

然后使用多模态编码器通过跨模态注意力去融合图片特征和文本特征。

 

2. 作者提出了图片-文本对比(ITC)损失:

对齐图片特征和文本特征,使得其更容易用于多模态编码器执行跨模态学习。

帮助单模态编码器更好的理解图片和文本的语义

学习一个低维空间去嵌入图片和文本,可以使得图片-文本匹配目标挖掘更多有信息的样本。

 

3. 为了在噪声监督下学习,作者还提出了动量蒸馏(MoD):

在训练期间,通过获取模型的参数的移动平均值,保持模型的一个动量版本。然后使用动量模型生成伪目标作为额外的监督。

MoD模型不会因为产生与网络注释不同的输出而受到惩罚。

MoD不仅改进了预训练,也对下游任务的标注进行清洗。

 

方法

图-文对比学习

首先,图片编码器和文本编码器都会在图片序列和文本序列的首部加上[CLS]标签,表示学习到的图片全局表示。

之后的对比就是基于[CLS]向量的对比。

图片和文本的[CLS]分别用vclswcls表示,动量编码器的输出特征分别使用g'w(w'cls)和g'v(v'cls)表示

对比学习,是学习与动量编码器输出的相似度。

s(I,T)=gv(vcls)T g'w(w'cls)

s(T,I)= gw(wcls)Tg'v(v'cls)

对于每个图片和文本,计算归一化的图片对文本的相似度和文本对图片的相似度。

τ是温度超参数。Tm是动量编码器输出的所有图片的[CLS],Im是动量编码器输出的所有文本[CLS]。

图文对比学习损失ITC如下:

其中H为交叉熵损失,y为Ground Truth标签。(在实际预训练中,代码中y采用的是伪标签)

已知交叉熵损失

代入ITC损失,得到

其中预测概率p为

  

其中s(I,T)是当前Image与一个Text的相似度。最终需要计算当前Image与所有Text的相似度,所以在源码中,是直接计算I与动量编码器Text队列中所有的Text的相似度。s(T,I)也是如此。

代入到ITC损失中得到源代码中的计算公式(对应源代码中不蒸馏部分)

#源码中,映射图片和文本的全连接层输出embedding_dim为256
#一批输入中,有N个图片和N个句子
#图片和文本队列大小都为57600,维度为256,也就是是可以保存57600个维度为256的[CLS]
#为了存储方便,队列形状设置为256 x 57600
image_embeds = self.visual_encoder(image) 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 
#image_feat形状为:(N,256)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')            
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)               
#text_feat形状为:(N,256)

idx = idx.view(-1,1)
#idx为图片-文本对的标签,分为一致2,中性1,对立0。
#原本形状为(N,),现在变为(N,1)
#idx转置成形状(1,N),idx_queue形状为(1,57600)
#然后将idx拼接到队列的头部得到idx_all,形状为(1,N+57600)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
pos_idx = torch.eq(idx, idx_all).float()
#idx形状为(N,1),idx_all形状为(1,N+57600)
#比较之后,比较矩阵为(N,N+57600),表示N个标签分别与N+57600个的比较结果。
#由于队列的头部是新添加的标签,新标签与其比较时,自然而然对角线为1。
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) #硬标签

with torch.no_grad():
    self._momentum_update()#更新动量编码器
    image_embeds_m = self.visual_encoder_m(image) 
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
    image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
    #image_feat_m转置后形状为:256 x 2 , Image队列的形状为256 x 57600
    #上述拼接操作是将队列复制一份,并将image_feat_m拼接到队列的头部!。
    
    text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')    
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
    #文本也一致    
    #text_feat_all和image_feat_all分别为text队列和image队列中所有的[CLS]集合

#计算图文特征分别对队列中所有特征的相似度
sim_i2t = image_feat @ text_feat_all / self.temp 
sim_t2i = text_feat @ image_feat_all / self.temp

loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()#计算与硬标签的交叉熵损失
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()

loss_ita = (loss_i2t+loss_t2i)/2

self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

 动量蒸馏图文对比损失

带蒸馏则需要动量编码器输出的新样本与队列中所有样本的相似度

然后最小化q和p之间的KL散度

代入原始式子

最小化KL散度的等价关系如下

最小化原式子等价于

得到如下公式(对应于源代码中的式子):

 

源代码如下:

#源码中,映射图片和文本的全连接层输出embedding_dim为256
#一批输入中,有N个图片和N个句子
#图片和文本队列大小都为57600,维度为256,也就是是可以保存57600个维度为256的[CLS]
#为了存储方便,队列形状设置为256 x 57600
image_embeds = self.visual_encoder(image) 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 
#image_feat形状为:(N,256)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')            
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)               
#text_feat形状为:(N,256)

idx = idx.view(-1,1)
#idx为图片-文本对的标签,分为一致2,中性1,对立0。
#原本形状为(N,),现在变为(N,1)
#idx转置成形状(1,N),idx_queue形状为(1,57600)
#然后将idx拼接到队列的头部得到idx_all,形状为(1,N+57600)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
pos_idx = torch.eq(idx, idx_all).float()
#idx形状为(N,1),idx_all形状为(1,N+57600)
#比较之后,比较矩阵为(N,N+57600),表示N个标签分别与N+57600个的比较结果。
#由于队列的头部是新添加的标签,新标签与其比较时,自然而然对角线为1。
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
        
with torch.no_grad():
    self._momentum_update()#更新动量编码器
    image_embeds_m = self.visual_encoder_m(image) 
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
    image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
    #image_feat_m转置后形状为:256 x 2 , Image队列的形状为256 x 57600
    #上述拼接操作是将队列复制一份,并将image_feat_m拼接到队列的头部!。
    
    text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')    
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
    #文本也一致
    
    #text_feat_all和image_feat_all分别为text队列和image队列中所有的[CLS]集合
    #动量蒸馏,创建软标签
    sim_i2t_m = image_feat_m @ text_feat_all / self.temp 
    sim_t2i_m = text_feat_m @ image_feat_all / self.temp   
    sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
    sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 

#计算图文特征分别对队列中所有特征的相似度
sim_i2t = image_feat @ text_feat_all / self.temp 
sim_t2i = text_feat @ image_feat_all / self.temp

#动量蒸馏,计算与软标签的等价KL散度。
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
    
loss_ita = (loss_i2t+loss_t2i)/2

self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

 

标签:dim,idx,text,image,ITC,损失,ALBEF,feat,self
From: https://www.cnblogs.com/RedNoseBo/p/17845565.html

相关文章

  • FreeSWITCH模块开发
    FreeSWITCH内核开发 1FreeSWITCH模块加载流程1.1main函数的主要工作FreeSWITCH在main函数中除了初始化异常处理程序,解析软交换启动参数(比如:-nc-nonat-conf-db等)之外,其核心就是调用switch_core_init_and_modload()函数初始化FreeSWITCH内核以及加载外围模块,具体加载哪些模......
  • DDOS攻击,流量超导致经济损失,考虑是否自己托管服务器,不要用云服务器
    云服务器还是要小心,流量超标。erwa前几天阿里云宕掉了。考虑有些大一点的公司在考虑是否自己托管服务器,不要用云服务器。云服务器,续费涨价,故障,流量,扩展等问题,根据自己的需要考虑是否采用。   ......
  • Switch选择结构 反编译待完善
     ......
  • 无线信道-路径损失以及信道衰落
    看了很多论文有关无线的论文,一直对他的论文里的信道模型很迷惑,大体结合搜到的资料以及论文整理一下。1、衰落\(\quad\)无线通信里,信号强度的变化可以分为大尺度衰落(Large-scalefading)和小尺度衰落(Small-scalefading),这两者由不同的物理现象引起,并在不同的尺度上影响信号。(1)大......
  • switch(jdk8)
     本质字节码int类型1int=4bytepublicstaticvoidswitchTest(inta){switch(a){case1:System.out.println("1");break;case2:System.out.println("2"......
  • 【Java基础】Java中switch的多种写法
    Java中switch的多种写法代码需求:键盘录入一个数字(代表星期几),判断是工作日还是休息日switch最基础写法 publicstaticvoidswitchTest(){while(true){System.out.println("请输入:");Scannersc=newScanner(System.in);......
  • 论文精读:用于少样本目标检测的元调整损失函数和数据增强(Meta-tuning Loss Functions a
    论文链接:Meta-TuningLossFunctionsandDataAugmentationforFew-ShotObjectDetectionAbstract现阶段的少样本学习技术可以分为两类:基于微调(fine-tuning)方法和基于元学习(meta-learning)方法。基于元学习的方法旨在学习专用的元模型,使用学到的先验知识处理新的类,而基于微......
  • 损失函数波动不收敛
      1.数据集不同类别样本数据不均匀,导致的 ......
  • 损失函数---训练集降低,验证集升高
     损失函数在训练集下降而在验证集上升,通常被称为过拟合(overfitting)的现象。这意味着模型在训练数据上表现得很好,但在新的、未见过的数据上表现较差。过拟合可能是由于模型过于复杂,以至于学到了训练数据中的噪声或细微特征,而这些特征在验证数据中并不普遍存在。 我通过降低学......
  • 损失函数Loss越来越大
     代表什么:预测值和真实值越来越大,模型效果不好 为啥?#classMLPModel(nn.Module):#def__init__(self,input_size):#super(MLPModel,self).__init__()#self.fc1=nn.Linear(input_size,128)#self.fc2=nn.Linear(128,64)#......