Comment: Accepted to NeurIPS 2023
对齐提示:用于zero-shot泛化的测试时提示分布对齐
摘要
CLIP等视觉语言模型的zero-shot泛化已经引领它们在下游任务中使用提示学习。先前的工作已经表明使用熵最小化进行测试时提示调优,调整文本提示适应未见过的领域。尽管这样的方法非常高效,但是它们忽略了在不可见领域中性能下降的关键因素——分布偏移。
本文使用提示调优来将分布外测试数据与源域数据进行对齐来解决以上问题。本文测试时使用单个测试样本通过最小化特征分布偏移来调整多模态提示,从而弥合测试域中的差距。
与领域适应的基准进行评估,本文方法比现有提示学习方法提高了,比基准MaPLe提高了3.08%。在未知类别的跨数据集泛化中,与现有的最好的方法相比,本文方法在10个数据集上得到一致提高。
Introduction
深度神经网络( deep neural networks,DNNs )在大多视觉识别任务中的表现超过了人类。然而,当测试数据与训练数据相同分布时,它们的表现通常是令人印象深刻的。在大多数实际应用中,由于自然变化或传感设备变化等因素,训练和测试数据的分布可能会显著不同。模型在推断过程中对不可见分布偏移的敏感性导致其性能下降。之前的大量工作已经探索了测试时间自适应作为一种机制来克服测试数据中分布偏移的问题。然而,对于受欢迎的基础模型,即在大规模视觉语言数据集上训练的大型DNN,其测试时间自适应探索已微乎其微。
在视觉、语言和音频等多种模态的交汇处出现的基础模型,在众多下游应用中被证明是有效的。在这些模型中,Vision-Language(V-L)模型CLIP已经在大规模图像-文本对上进行预训练,并且可以很好地推广到zero-shot识别任务中。CLIP模型具有并行的视觉和语言编码分支,并利用从这两个分支获得的嵌入向量之间的相似性来对输入图像进行分类。在推理时,使用手工制作的提示,如" A photo of a",作为文本编码器的查询。然而,将CLIP高效地适配到特定的下游任务仍然是一个具有挑战性的问题。对这类模型进行微调会带来失去其固有泛化能力的风险。相反,最近的方法显示了在训练数据上进行提示学习的方法,而不是使用手工提示,使模型能够更好地适应训练数据分布。
现有的提示学习方法部署在训练阶段,根据下游任务的训练数据学习具有代表性的提示。这种传统的方法没有直接处理测试集的分布偏移。最近TPT(Test-time Prompt tuning)利用提示学习的能力,通过动态调整文本提示来使模型适应测试样本,从而进行测试时间自适应。对于图像分类,模型最小化样本的增强视图(具有高置信度)的熵来更新提示。然而,TPT并没有明确地将预训练的CLIP对齐,以了解测试样本分布。
对于V-L基础模型中测试时间自适应,对于高zero-shot泛化在弥合预训练数据与下游测试集之间的分布差距至关重要。因此本文提出了一种使用提示学习的测试时间token分布对齐策略PromptAlign。TPT是仅在文本分支上提示调优,从而实现分布对齐,这样存在结构上的限制。此外,在测试阶段文本分支和视觉分支之间没有发生知识迁移,即两个分支在处理输入数据时独立的。对于给定数据集,由于输入是相同的类标签,因此文本编码器特征将是静态的,因此token的分布对齐只能在视觉分支上进行。考虑到这些限制,为了进一步扩展测试时间自适应提示的强度,本文提出使用多模态提示学习模型( MaPLe) 进行分布对齐。Prompt Align将离线计算的代理源数据集的图像令牌嵌入的均值和方差与测试样本的图像令牌嵌入进行对齐。本文使用token对齐策略扩展TPT,使其能够桥接测试数据(图1a)中的分布偏移。对于每一个输入的测试样本,随机获得增广视图,并将其输入到模型中,以获得token嵌入统计量。在没有明显计算开销的情况下,同时更新CLIP的文本和视觉分支上的提示,以联合最小化特征分布偏移和预测的熵。
本文通过在领域泛化和跨数据集泛化两个具有代表性的基准上评估零样本泛化,展示了Prompt Align的有效性。在领域泛化设置中,本文方法在4个数据集上的基线模型平均提高了3.08 %,与现有最先进的方法相比,具有最高的平均Top - 1准确率。在跨数据集的情况下,本文方法比现有的使用测试时间提示调优的先进方法获得了1.82 %的绝对平均提升,同时在10个数据集中的8个数据集中获得了最好的Top - 1准确率。
本文贡献
(1)在仅给定单个测试样本的情况下,本文引入一种V-L模型的分布对齐策略,以提高测试时间的自适应性。分布感知预训练的CLIP有效地缩小了测试域上的分布差距。据我们所知,这是第一个探索V - L模型在测试时分布对齐的研究。
(2)本文制定了一个分布对齐损失,利用离线计算的源数据统计来促进测试样本token分布与源数据token分布对齐。本文使用多模态提示学习方法结合了令牌分布对齐和熵最小化的优势。
(3)由于CLIP预训练数据没有公开发布,本文研究了ImageNet的统计量作为源分布的可能候选,实证结果表明ImageNet是CLIP等大规模V - L模型的有效代理源数据集。
(4)通过在领域泛化和跨数据集基准测试中的大量实验验证了本文方法Prompt Align。Prompt Align在测试时间上提高了CLIP的泛化性,超越了现有的提示调优方法,取得了目前最好的结果。
Method-Preliminaries
Contrastive Language-Image Pre-training (CLIP)
CLIP由两个并行编码器组成,文本编码器用于将文本输入映射为文本特征向量,图像编码器用于将视觉输入映射为图像特征向量。
Image Encoder
在这项工作中,选择ViT作为图像编码器。
(1)给定输入图像I∈RH×W ×3,包含K个transformer层的图像编码器将图像打成M个固定大小的patch.
(2)得到投影patch的嵌入向量E0∈RM×dv。
(3)嵌入向量和一个可学习的类标记ck一起被输入到图像编码器Vk+1的第(k+1)层,并依次通过以下变换:
(4)为了获得最后的图像特征向量,将最后一个transformer层的类标记ck通过ImageProj投影到共同给特征空间中:
Text Encoder
(1)文本编码器包含K个transformer层,将输入的单词打成若干token,然后映射为单词嵌入W0。然后将单词嵌入向量送到文本编码器的Lk+1层:
最终的文本表示z是通过TextProj与最后一个transformer层输出的token投影到潜在的嵌入空间:
Zero Shot Prediction
对于zero-shot预测,在CLIP的语言分支中引入提示,通过加入与下游任务相关联的每个类名来构成文本输入。然后选择余弦相似度分数最高的类别作为图像的预测标签:
Prompt tuning on downstream tasks
CLIP从数百万个包含噪声的图像文本对中进行训练,学到大量知识。为了有效地提取CLIP学到的丰富的特征,最近的方法保证文本编码器和图像编码器冻结,然后添加额外的可学习提示。这些提示在不毁坏CLIP预训练学习到的特征的情况下修改了模型输入的上下文向量。在图像或文本编码器端添加提示,并学习针对特定任务定制的上下文信息。本文使用了一个最近引入的多模态提示基线MaPLe,它在文本和图像编码器上学习提示token。
具体来说,分别学习V个视觉提示和T个文本提示作为两个分支中可学习向量。图像分支输入:编码器进行处理,得到相应图像特征向量。文本分支类似输入:得到文本特征表示。本文方法使用MaPLe中使用的深度提示,以及transformer块中的文本提示和条件图像提示。我们用p共同表示视觉提示和文本提示。
Test-time prompt tuning
测试时提示调优(TPT)旨在利用CLIP丰富的知识以zero-shot的方式提高其泛化能力。TPT可以看作一种方法,为模型提供一个为每个单独的测试样本定制的上下文向量,以便准确回忆出CLIP中包含的知识。
在推理过程中,对于给定的测试样本Xtest,随机生成若干个增广试图。将增广试图按照置信度从高到低排序(熵值从低到高),使用选择过滤器进行过滤。
然后使用过滤后的增强视图预测的平均熵,使用以下目标函数以无监督方式对提示进行更新:
其中代表模型产生增强视图经过选择过滤器后产生类概率的均值。
Method-PromptAlign
单模态测试时调优(TPT)在推理时通过最小化交叉熵更新文本提示,但它没有明显处理测试集中出现的分布偏移的问题,因此它是次优的。解决这个问题的一种方法是将测试样本带入源域来对齐源域和目标域的分布。但是TPT仅在使用静态标签的文本分支中更新提示,这对齐token分布时存在架构上的限制。因此,token分布对齐只能在文本分支上对齐。因此,本文提出使用多模态提示学习模型MaPLe来处理源域和目标域中分布偏移的问题。
给定一个测试样本Xtest,随机得到多个增强视图后进行置信度过滤器,并通过具有深度提示的视觉编码器。在视觉编码器的每一层,我们计算测试样本的均值和方差与代理数据集(模拟源域数据分布)的均值和方差之间的token对齐。最终目标是结合熵和对齐损失来更新给定测试样本的提示。
Proxy source dataset
为了计算源域数据集上的token嵌入统计,我们需要CLIP模型的预训练数据集。但是CLIP是在超过4亿的图像文本对上预训练,无法公开获取。无论如何,先前的工作表明LAION400M可以作为训练数据集来达到CLIP的性能,因此,本文使用LAION400M的子集作为训练集的代理数据集。此外,CLIP还进行大量的调整,以在ImageNet上实现zero-shot性能。因此,本文使用Imagenet作为计算token分布均值和方差的代理数据集,其统计值是离线计算的,并在测试时直接使用。
Token Distribution Alignment via multi-modal prompting
(1)给定一个测试样本,利用一组增广操作集合H生成测试样本的Nk个随机视图。
(2)在CLIP模型的视觉编码器的每个transfromer层的输出端中,计算测试样本的Nk个增强视图的token嵌入向量的均值和方差统计量。类似,源域数据统计是以离线方式预先计算的。其中(T)代表测试样本分布,(D)代表源域数据分布。
(3)本文在测试样本中计算对齐的token均值和方差方法:
其中表示视觉编码器第l个transformer层测试样本token向量的均值和方差表示。表示输入x的增强视图在第l层的提示token嵌入向量。
(4)类似,对于图像编码器中的每一层l,预先计算源域样本的统计量:
其中θv代表CLIP预训练模型的视觉编码器参数。
(5)使用L1损失计算测试样本与源数据统计量的均值和方差之间的token分布对齐损失:
(6)将对齐损失Lalign加入到熵损失Lentropy中,得到最终的目标损失Lfinal进而更新提示:
其中β代表超参数,用来控制对齐损失对总目标损失函数的概念。通过更新提示使源域与目标域分布对齐。
Discussion on Lfinal
综上所述,本文的测试时损失包含了熵最小化损失Lentropy和分布对齐目标损失Lalign。Lentropy目标加强了同一测试样本不同试图的预测一致性,从而对测试时间出现的各种情况具有鲁棒性。另一方面,Lalign有效缩短了领域偏移,使测试样本更接近预训练的CLIP分布空间,从而加强CLIP对测试样本的理解。
这两个损失目标的结合满足了CLIP对不同样本变化的鲁棒性,增强了CLIP对潜在测试样本域的理解,以获得更好的泛化性。
Conclusion
本文介绍了一种新的方法PromptAlign,用于增强视觉-语言( Vision-Language,V-L )模型的零样本泛化的测试时间适应性。本文提出的方法通过token分布对齐将测试样本统计量与源数据分布的统计量进行显式对齐来弥合测试样本与源数据分布之间的差距。为了实现这一点,本文结合多模态提示,以促进在测试期间transformer层之间的token分布的对齐。通过大量的实验,Prompt Align在领域泛化和跨数据集评估设置方面表现出优于现有的最新CLIP零样本泛化方法的性能。
标签:Shot,CLIP,Prompting,提示,Align,样本,token,测试,对齐 From: https://blog.csdn.net/m0_54248968/article/details/142099812