首页 > 其他分享 >FD-Align: Feature Discrimination Alignment for Fine-tuning Pre-Trained Models in Few-Shot Learning

FD-Align: Feature Discrimination Alignment for Fine-tuning Pre-Trained Models in Few-Shot Learning

时间:2024-07-08 14:30:31浏览次数:27  
标签:Pre 虚假 Shot tuning CLIP 特征 模型 微调 训练

文章汇总

动机

image.png
CLIP注意图更关注背景,全面微调后的CLIP关注在了非显著特征的地方。FD-Align注意图倾向于关注标签相关的信息。

解决办法

image.png
总损失有两个损失函数组成:
image.png

对Visual Encoder进行微调

冻结CLIP的文本编码器 g 0 g_0 g0​,并预训练CLIP 的视觉编码器 f 0 f_0 f0​, f t f_t ft​提取图像 x x x的特征为 f t ( x ) f_t(x) ft​(x)。
对于每个类别 y y y(总共 N N N个类别),有 M M M个提示模板,即 [ P 1 , y ] , . . . , [ P M , y ] [P_1,y],...,[P_M,y] [P1​,y],...,[PM​,y]。作者对 M M M个提示模板取了个平均,故 y y y类的文本原型表示为
image.png
视觉特征和文本原型特征作交叉熵损失计算
image.png
相似度计算 s ( ⋅ , ⋅ ) s(\cdot,\cdot) s(⋅,⋅)为余弦相似度。

缓解微调CLIP会对未知数据的鲁棒性影响

对于每个提示模版 P j P_j Pj​在所有类(框架图为 N N N个类)上取均值
image.png
微调后模型的视觉特征与 u P j s p u r i o u s , j = 1 , . . . , M u^{spurious}_{P_j},j=1,...,M uPj​spurious​,j=1,...,M算相似度,挑选出最佳的提示模版 P 1 P_1 P1​
image.png
同理,冻结的视觉特征与 u P j s p u r i o u s , j = 1 , . . . , M u^{spurious}_{P_j},j=1,...,M uPj​spurious​,j=1,...,M算相似度,挑选出最佳的提示模版 P 2 P_2 P2​
image.png
P 1 , P 2 P_1,P_2 P1​,P2​计算损失
image.png

摘要

由于数据的可用性有限,现有的小样本学习方法不能达到令人满意的效果。相比之下,大规模的预训练模型,如CLIP,展示了显著的few-shot和zero-shot能力。为了提高下游任务的预训练模型的性能,经常需要对下游数据的模型进行微调。然而,在分布移位的情况下,对预训练模型进行微调会导致其泛化能力下降,而在少数样本学习中,有限的样本数量使模型极易出现过拟合的情况。因此,现有的微调小样本学习方法主要集中在微调模型的分类头或引入额外的结构。在本文中,我们介绍了一种称为特征识别对齐(FD-Align)的微调方法。我们的方法旨在通过在整个微调过程中保持伪特征的一致性来增强模型的可泛化性。大量的实验结果验证了我们的方法对ID和OOD任务的有效性。经过微调后,该模型可以与现有方法无缝集成,从而提高性能。我们的代码可以在https://github.com/skingorz/FD-Align中找到。

1.介绍

对比语言图像预训练模型(CLIP)[1]代表了多模态深度学习的突破性发展。通过对比学习,CLIP在统一的嵌入空间内对齐视觉和文本表示,并在各种下游任务中表现出优异的性能,包括图像分类[2,3]、检测[4]和分割[5,6],这些任务通常是通过使用下游数据对CLIP进行完全微调来完成的。然而,在许多现实场景中,可用的标记数据量通常不足。因此,完全微调将导致过拟合并显著降低模型性能。
为了缓解这一挑战,我们考虑使用与下游目标数据集相关的代理数据集对CLIP进行微调,旨在获得一个能够有效泛化到少数目标任务的模型。在代理数据集上直接完全微调CLIP是不可行的,因为微调后的模型可能会对代理数据过拟合或具有较差的out-of-distribution (OOD)泛化[7],从而限制了其在目标任务上的性能。如图1中所示的示例所示,与原始CLIP相比,完全微调的CLIP倾向于更多地关注局部区域,而较少关注前景。
这种局部关注会削弱模型对虚假相关的鲁棒性[8],导致完全微调后的CLIP的OOD泛化效果较差。
在本文中,我们的目标是在微调期间保持CLIP对虚假相关的鲁棒性,即其区分虚假和因果特征的能力。特别是,因果特征代表与类相关的特征,而虚假特征可能是与类的上下文相关的特征。我们希望经过微调的CLIP既能学习到新类的因果特征,又能保持对伪特征的识别能力。为此,我们提出了一种特征判别对齐方法(FD-Align)。具体来说,我们引入了一个伪特征分类器,确保伪特征的分类概率分布在整个微调过程中保持一致。利用CLIP文本和视觉特征的强大对齐能力,我们利用与类别无关的描述(即上下文)的文本特征作为伪特征原型。对图像特征和伪特征原型进行相似性度量,以确定当前图像在伪特征上的概率分布。通过约束模型提取的图像特征在微调前后的概率分布,保证了模型提取的伪特征的一致性。同时,在学习代理数据集的分类能力的同时,也保证了模型在微调后对虚假关联的鲁棒性。
image.png
图1:(a)狗的图像。(b) CLIP注意图,与狗相比,它更关注背景。©经过全面微调后的CLIP注意图,它更多地关注具有非显著特征的地点。(d) FD-Align调优后的注意图,它倾向于优先考虑狗的因果信息,同时也注意到一小部分虚假信息。
我们的方法在对代理数据集进行微调的同时保持了模型对虚假相关性的鲁棒性。如图1d所示,一方面,与CLIP相比,带有FD-Align的微调模型更好地聚焦于狗。另一方面,与完全微调CLIP的局部注意相比,FD-Align对一些虚假信息的注意较少。这种关注因果信息和虚假信息之间的平衡确保了模型对虚假相关的鲁棒性,从而确保了模型的OOD泛化。大量的实验验证了我们的方法的鲁棒OOD性能,以及在分布(ID)性能的改进。此外,如图2所示,FD-Align微调的模型直接提高了现有方法的准确性,而不会引入额外的推理成本。
image.png
本文的贡献如下:(1)提出利用文本特征获取图像的伪特征;(2)提出了一种特征判别对齐微调架构,通过对微调前后的伪特征提取模型进行对齐,保证了微调模型的OOD性能;(3)充分的实验表明,我们的方法可以显著提高ID和OOD在小样本学习中的性能,并且可以在不引入额外训练和推理成本的情况下提高现有方法的性能。

2.相关工作

小样本学习。小样本学习的主要目标是用少量样本训练出性能优异的模型。先前的方法主要是在基础数据上训练模型,并在没有任何共享类别的新类数据上评估它们的性能。方法比如MAML[9]采用元学习在基础数据上训练基础学习器,然后在有限数量的新数据上对其进行微调,得出适合新数据的模型。ProtoNet[10]介绍了使用度量学习进行训练。近年来,有一些研究开始将文本的模态引入到小样本图像分类中,如AM3[11]、TRAML[12]、FILM[13]等。所有这些模型都是从零开始训练的。随着预训练的双峰模型(如CLIP)的出现,利用这些预训练的模型可以获得更精确的图像特征。因此,目前的研究主要集中在如何利用CLIP提取的特征来增强小样本学习的能力。例如,CoOp[14]和CoOp[15]模型用可学习的向量提示上下文词,保持所有预训练参数固定。TipAdapter[3]和APE[16]不需要任何反向传播来训练适配器,而是通过从少量训练集构建的键值缓存模型来创建权重。VPT[17]在输入空间中引入了额外的可学习参数。然而,这些方法都是在冻结backbone的情况下进行处理的,而本文的目的是进一步探索对backbone本身进行微调的可能性。虽然这些方法表现出优异的性能,但它们没有利用预训练模型的潜力。
预训练模型的微调。最直接的方法是直接对预训练模型进行微调。然而,完全微调会降低预训练模型的OOD性能[7]。WiSE-FT[18]通过整合zero-shot和微调模型的权重来提高性能。Kumar等[7]首先进行线性探测,然后进行全面微调,以确保模型的OOD性能。Xuhong等[19]引入了一个额外的正则器来约束零射击和微调CLIP之间的l2距离。另一方面,Mukhoti等[20]通过在特征空间中约束zero-shot和微调CLIP提取的图像特征,防止了基础模型能力的退化。然而,现有的方法很少深入研究使用有限数据对预训练模型进行微调的策略。
虚假的相关性。虚假相关性表示在训练过程中大多数情况下具有欺骗性的启发式,但并不总是适用[21]。例如,当训练实例涉及草地上的奶牛时,很容易将草地的存在误解为对奶牛进行分类的因果决定因素。这种将草视为牛的误解体现了一种虚假的相关性。消除虚假特征通常会提高OOD性能;然而,这可能伴随着ID性能的下降[22]。此外,即使对虚假特征有了全面的了解,其提取仍然是一项不平凡的任务[23]。虚假相关性的影响在多模态预训练模型中持续存在[24]。然而,当模型参数和数据集规模足够大时,Vision Transformer (ViT)对虚假相关性的鲁棒性增强[8]。此外,CLIP架构增强了视觉编码器对伪相关的鲁棒性[25]。因此,提高对伪相关的鲁棒性可以有效地保持模型的OOD性能。

3.提出的方法

image.png
图3:将类名和提示符组合并输入到文本编码器中,以获得文本嵌入。我们分别计算提示和类维的均值,推导出类原型和提示嵌入。一方面,利用微调后的视觉编码器提取图像特征,并基于类原型计算类分布,计算类损失;另一方面,我们使用伪原型校正(SPC)模块对提示嵌入进行校正。通过计算图像特征与杂散原型之间的余弦相似度,得到杂散特征的分布并计算杂散损失。

3.1.问题定义

假设我们有一个预训练的CLIP[1],它包含一个视觉编码器 f 0 f_0 f0​和一个文本编码器 g 0 g_0 g0​。此外,我们可以访问小样本代理数据集 D ⊂ X × Y D\subset \mathcal{X} \times \mathcal{Y} D⊂X×Y,其中每个类都有非常有限的样本,每个样本包括图像 x x x及其相应的标签 y y y。目标是使用该代理数据集微调预训练的CLIP,旨在提高其zero-shot性能,以实现与代理数据集相关的未见目标任务。

3.2.代理数据集的微调

我们在微调期间冻结CLIP的文本编码器 g 0 g_0 g0​,并使视觉编码器可学习。首先,我们使用预训练CLIP f 0 f_0 f0​的视觉编码器参数初始化视觉编码器 f t f_t ft​。然后使用视觉编码器 f t f_t ft​提取图像 x x x的特征 f t ( x ) f_t(x) ft​(x)。借助CLIP出色的文本视觉对齐能力,我们使用每个类的类名的文本特征作为类原型。在CLIP之后,对于任何类 y y y,我们组合 M M M个提示模板 ( P 1 , . . . , P M ) (P_1,...,P_M) (P1​,...,PM​)与类名,并获得 M M M个提示 [ P 1 , y ] , . . . , [ P M , y ] [P_1,y],...,[P_M,y] [P1​,y],...,[PM​,y]。然后,我们使用文本编码器 g 0 g_0 g0​提取上述 M M M个提示符的特征。随后,我们计算了的 M M M个特征的平均数得到对应类的原型,即 y y y类的原型为
image.png
我们计算图像特征和类原型之间的余弦相似度 s ( ⋅ , ⋅ ) s(\cdot,\cdot) s(⋅,⋅),并生成图像的类分布。最后,利用交叉熵损失计算类损失
image.png
其中 Y \mathcal{Y} Y是标签集。

3.3虚假特征约束

在代理数据上完全微调CLIP会影响模型对未知数据的鲁棒性。为了在微调期间保持模型对分布外数据的性能,我们在微调期间保持模型对虚假相关的鲁棒性。即保持模型在微调前后提取的杂散特征不变。我们首先计算每个提示模板 P j P_j Pj​在作为提示模板 P j P_j Pj​原型的所有类上的特征均值,即:
image.png
我们可以计算由微调模型提取的特征与虚假原型之间的相似度,并得出虚假特征的分布如下
image.png
类似地,我们可以使用预训练的视觉编码器f0来提取特征,并产生以下虚假特征的分布。
image.png
我们可以通过保持模型在伪特征上的概率分布在微调前后一致来保证模型在微调前后的伪特征保持一致,即
image.png
最后,我们在微调过程中对(1)和(2)进行优化,以保证分类能力和OOD鲁棒性:
image.png
在本文中我们将 α \alpha α设为1, β \beta β设为20。

3.4.伪原型校正(Spurious Prototype Correction)

提示模板通常是手动设计的,或者由大型语言模型(如GPT)生成。这些模板通常包含冗余或不合逻辑的提示。因此,这些不精确和冗余的提示模板计算的伪特征原型缺乏准确性。因此,对这些伪特征原型进行滤波和处理是必要的。
某些提示模板可能缺乏实际意义或不合理,使其不适合作为虚假特征合并,例如“itap of {class}”。这种情况可能导致虚假原型的不准确性。为了解决这个问题,我们采用隔离森林算法[26]来消除与虚假特征相关的无意义原型,即 µ s p u r i o u s : = ISOLATIONFOREST ( µ s p u r i o u s , n ) µ^{spurious}:= \text{ISOLATIONFOREST}(µ^{spurious}, n) µspurious:=ISOLATIONFOREST(µspurious,n)。我们将保留 n n n个表现出最高程度合理性的原型。
此外,在某些情况下,某些提示会表现出过度的相似性。例如,提示““a photo of a {class}”,“a photo of my {class}”和“a photo of the {class}.”。显示出显著的相似之处。在这种情况下,一条虚假信息可能对应多个提示。然而,一些虚假的信息只与一个提示一致。因此,在分类过程中,与单个提示相对应的虚假信息概率的相对权重减小。为了解决这个问题,我们采用k-means算法来合并由类似提示产生的重复虚假特征,即 µ ~ s p u r i o u s : = k − m e a n s ( µ s p u r i o u s , k ) \tilde µ^{spurious}:= k-means(µ^{spurious}, k) µ~spurious:=k−means(µspurious,k),其中 k k k是聚类中心的数量。

4.实验

image.png
image.png

5.结论

在本文中,我们引入了一种特征判别对齐微调(FD-Align)方法,在小样本学习中对模型进行预训练。利用CLIP出色的文本-视觉对齐功能,我们使用与类别无关的描述的文本特征作为伪特征原型。此外,我们在微调前后对模型提取的图像特征在伪特征上的概率分布进行约束,以保证模型在微调后的鲁棒性。实验结果证实了我们的方法在提高微调性能的同时确保跨分布的鲁棒性的有效性。

参考资料

论文下载(NeurIPS 2023)

https://arxiv.org/pdf/2310.15105

image.png

代码地址

https://github.com/skingorz/FD-Align

标签:Pre,虚假,Shot,tuning,CLIP,特征,模型,微调,训练
From: https://blog.csdn.net/weixin_50917576/article/details/140267654

相关文章

  • Navicat 推出免费精简版 —— Navicat Premium Lite
    2024年6月25日,Navicat宣布推出免费的数据库管理开发工具——NavicatPremiumLite。针对入门级用户,支持基础的数据库管理和协同合作功能。NavicatPremiumLite下载地址:https://www.navicat.com.cn/products/navicat-premium-lite官方介绍如下:NavicatPremiumLite是......
  • PREEMPT_RT 内核是如何实现其实时性的
    PREEMPT_RT内核是通过以下几个关键机制来实现实时性的:抢占式内核调度器:PREEMPT_RT内核使用了抢占式的调度器,可以及时中断正在运行的进程,并立即切换到更高优先级的实时进程执行。这与标准内核的协作式调度器不同,后者只有在进程主动放弃CPU时才能切换到其他进程。中断路径......
  • Fundamentals of Machine Learning for Predictive Data Analytics Algorithms, Worke
    主要内容:本书介绍了机器学习在预测数据分析中的基本原理、算法、实例和案例研究,涵盖了从数据到决策的整个过程。书中涉及机器学习项目生命周期的各个方面,包括数据准备、特征设计和模型部署。结构:本书分为五个部分,共计14章和若干附录:引言(IntroductiontoMachineLearn......
  • CosyVoice多语言、音色和情感控制模型,one-shot零样本语音克隆模型本地部署(Win/Mac),
    近日,阿里通义实验室开源了CosyVoice语音模型,它支持自然语音生成,支持多语言、音色和情感控制,在多语言语音生成、零样本语音生成、跨语言声音合成和指令执行能力方面表现卓越。CosyVoice采用了总共超15万小时的数据训练,支持中英日粤韩5种语言的合成,合成效果显著优于传统语音合成模......
  • 07浅谈大语言模型可调节参数tempreture
    浅谈temperature什么是temperature?temperature是大预言模型生成文本时常用的两个重要参数。它的作用体现在控制模型输出的确定性和多样性:控制确定性:temperature参数可以控制模型生成文本的确定性,大部分模型中temperature取值范围为(0-1]。接近0时,模型倾向于选择概率最......
  • Node.js之Express
    Express介绍Express是一个简洁、灵活的node.jsWeb应用开发框架,是目前最流行的基于Node.js的Web开发框架.它提供一系列强大的功能,比如:模板解析静态文件服务中间件路由控制还可以使用其他模块来帮助你创建各种Web和移动设备应用使用express本地安装$npminstallexp......
  • 好消息!数据库管理神器 Navicat 推出免费精简版:Navicat Premium Lite
    前言好消息,前不久Navicat推出了免费精简版的数据库管理工具NavicatPremiumLite,可用于商业和非商业目的,我们再也不需要付费、找破解版或者找其他免费平替工具了,有需要的同学可以马上下载使用起来。工具官方介绍NavicatPremiumLite是Navicat的精简版,它包含了用户执行主要......
  • 用免费WordPress和Cloudflare打造媲美收费服务的网站
    你是否曾因为网站搭建的高昂费用而犹豫不决?别担心,我来告诉你一个几乎零成本的解决方案,让你轻松拥有一个功能强大的网站。通过免费域名、免费PHP主机、WordPress程序和CloudflareCDN服务的组合,你可以打造出一个媲美收费服务的网站。首先,你需要一个域名。在lita.eu.org注册免费......
  • wx.config的前后端实现express和react
    wx.config是微信JS-SDK的配置接口,用于初始化微信JS-SDK。为了确保安全性,微信要求每次调用JS-SDK时都需要进行签名认证。签名认证需要使用jsapi_ticket,而jsapi_ticket需要通过access_token获取。以下是实现wx.config的步骤:后端部分获取AccessToken你需要定期获取并缓存acce......
  • ContentPresenter 的作用
    我发现WPF自定义控件模板的时候有时候写ContentPresenter,有时候不写,不管写不写ContentPresenter都能自定义好一个漂亮的控件,为什么,那么ContentPresenter的作用是什么,写不写的区别是什么ContentPresenter是WPF中一个非常重要的控件,它的作用是显示控件的内容。以下是ContentPre......