首页 > 其他分享 >[论文阅读] Domain generalization by learning and removing domain-specific features

[论文阅读] Domain generalization by learning and removing domain-specific features

时间:2024-04-03 19:44:41浏览次数:26  
标签:Domain right removing domain 领域 分类器 源域 mathcal left

1 Introduction

最近的研究发现,DNNs倾向于以与人类不同的方式学习决策规则 [17, 21, 16]。例如,在基于ImageNet的图像分类任务中,卷积神经网络(CNNs)倾向于学习局部纹理以区分对象,而我们人类则可能使用全局对象形状的知识作为线索。DNNs学到的特征可能只属于特定的领域,对其他领域不具有泛化性。例如,在现实世界的照片中,属于同一类别的对象具有相似的纹理,但在素描中,对象只由线条绘制,并不包含纹理信息 [27]。对于一个使用纹理来区分照片中对象的CNN,当应用于素描时,可以预期性能不佳。这种情况需要学习跨领域不变特征而不是学习特定领域的特征的DNNs。

一种典型的领域泛化方法是学习跨领域不变的表示[18, 30, 42, 3, 11, 14, 45, 31, 35]。这种方法基于这样的假设:每个领域都有其特定于该领域的特征,而所有领域共享领域不变的特征。这些方法并没有明确地告知深度神经网络应该有效地去除特定于领域的特征。相反,只希望通过实现学习域不变特征的最终目标来消除它们。我们主动去除特定于领域的特征,引导cnn学习领域不变特征进行分类。

在第一步中,每个领域特定分类器被设计为有效地从一个源域中学习领域特定特征。具体地,一个领域特定分类器被设计为在一个特定的源域内区分不同类别的图像。同时,这个分类器要求在任何其他源域中无法区分不同类别的图像。因此,在这种设计下,每个源域对应一个领域特定分类器。

在第二步中,编码器-解码器网络将输入图像映射到一个新的图像空间,在这个空间中,利用特定于领域的分类器从输入图像中去除上面学习到的特定于领域的特征。与第一步不同的是,这里的每个特定于域的分类器都无法区分对应源域中不同类之间的映射图像。与原始输入图像相比,映射图像预计包含更少的领域特定特征。然后将域不变分类器附加到编码器-解码器网络中,并使用映射的图像进行训练。通过这种设计,编码器-解码器网络可以主动去除特定于领域的特征,并且可以更好地引导域不变分类器学习域不变特征。经过训练后,编码器-解码器网络和领域不变分类器将用于对未见过的目标域进行分类。

值得注意的是,我们的框架与基于数据增强的领域泛化方法[43, 34, 46, 7]不同。我们的框架旨在从输入图像中去除领域特定特征,而基于数据增强的方法则生成具有新领域特定特征的各种图像。此外,我们的框架仅将输入图像映射到一个新的图像空间,并不对其进行增强以扩大训练数据集。

2 Proposed framework

2 提出的框架

假设我们有 \(N\) 个源域 \(\mathcal{D}_s=\left\{D_s^1, D_s^2, \ldots, D_s^N\right\}\),它们遵循不同的分布。对于每个域(数据集),\(D_s^i=\left\{\left(\mathbf{x}_j^i, y_j^i\right)\right\}_{j=1}^{n_i^s}\),其中 \(n_i\) 是 \(D_s^i\) 中样本的数量,\(\left(\mathbf{x}_j^i, y_j^i\right)\) 是第 \(i\) 个域中第 \(j\) 个样本的数据-标签对。根据文献,我们假设所有源域和目标域共享相同的标签空间。领域泛化的目标是利用这些源域 \(\mathcal{D}_s\) 来为未见的目标域 \(D_t\) 学习一个模型。

我们的工作受到最近的研究[32]的启发,该研究使用了一个“lens”网络(即图像到图像的转换网络),在自监督学习任务中从输入图像中移除“shortcut”(CNN可以快速学习的低级视觉特征,如水印和色彩偏差)。与此不同的是,我们的工作着重于从输入图像中移除领域特定特征,用于领域泛化任务。我们使用了一个类似于“lens”网络的编码器-解码器网络,但我们设计了一种不同的方法来利用编码器-解码器网络去除领域特定特征。在本节中,我们详细说明了我们的框架,并为我们的框架提供了理论分析。图1概述了整个框架。

image-20240403183101174

图1:LRDG提出的框架概述。 (a) 使用分类损失\(L_C\)在源域\(D_s^i\)上训练领域特定分类器\(F_i\),并使用剩余源域上的不确定性损失\(L_U\)。训练完成后,所有领域特定分类器的权重被冻结。(b) 使用重建损失\(L_R\)和通过领域特定分类器的不确定性损失\(L_U\)训练编码器-解码器网络\(M\)。同时,使用映射图像上的分类损失\(L_C\)训练领域不变分类器\(F\)。(c) 在测试阶段,编码器-解码器网络\(M\)和领域不变分类器\(F\)被用于目标域\(D_t\)上的分类。

2.1 Learning domain-specific features

我们的框架从训练 \(N\) 个单独的领域特定分类器 \(\mathcal{F}_S=\left\{F_1, F_2, \ldots, F_N\right\}\) 开始,其中分类器 \(F_i\) 被设计为仅使用源域 \(D_s^i\) 中的领域特定特征来区分图像。领域特定分类器 \(\mathcal{F}_S\) 不应使用领域不变特征作为线索。换句话说,\(F_i\) 预计能够有效地在 \(D_s^i\) 内区分不同类别的图像,但是对于任何其他域内的不同类别的图像,\(F_i\) 应该难以区分。除了 \(D_s^i\) 外的域被用来最大化分类的不确定性或者对 \(F_i\) 的分类增加对抗性以增加分类的难度。\(F_i\) 在除 \(D_s^i\) 外的域上的分类性能应该类似于随机猜测。

具体来说,分类器 \(F_i\) 通过最小化分类损失 \(\mathcal{L}_C^{F_S}\) 在 \(D_s^i\) 上进行训练,

\[\underset{\theta_i}{\arg \min } \mathbb{E}_{D_s^i \sim \mathcal{D}_s}\left[\mathbb{E}_{\left(\mathbf{x}_j^i, y_j^i\right) \sim D_s^i}\left[L_C\left(F_i\left(\mathbf{x}_j^i ; \theta_i\right), y_j^i\right)\right]\right], \tag{1} \]

并且最大化剩余域 \(\left\{D_s^1, \ldots, D_s^{i-1}, D_s^{i+1}, \ldots, D_s^N\right\}\) 上的不确定性损失 \(\mathcal{L}_U^{F_S}\),

\[\underset{\theta_i}{\arg \max } \mathbb{E}_{D_s^k \sim \mathcal{D}_s, k \neq i}\left[\mathbb{E}_{\left(\mathbf{x}_j^k, y_j^k\right) \sim D_s^k}\left[L_U\left(F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right)\right]\right] \tag{2} \]

其中 \(\theta_i\) 表示分类器 \(F_i\) 的参数。\(L_C\) 和 \(L_U\) 分别是分类损失函数和不确定性损失函数。我们使用交叉熵损失作为分类损失。对于不确定性损失,由于我们的目标是使预测类似于随机猜测,因此我们使用熵损失。

\[L_U\left(F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right)=-\sum_{l=1}^C p\left(y=l \mid F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right) \log p\left(y=l \mid F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right), \tag{3} \]

其中 \(C\) 是类别数量,\(p\left(y=l \mid F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right)\) 表示 \(\mathbf{x}_j^k\) 属于类别 \(l\) 的概率。

最小可能损失[32]是熵损失的一种替代方法。分类器首先预测图像并获取所有类别的概率。概率最低的类别称为最不可能的类别。将该图像标记为该类别。然后我们训练分类器预测最不可能的类别。最小可能损失为

\[L_U\left(F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right)=L_C\left(F_i\left(\mathbf{x}_j^k ; \theta_i\right), \hat{y}_j^k\right) \text {, where } \hat{y}_j^k=\underset{y}{\arg \min } p\left(y \mid F_i\left(\mathbf{x}_j^k ; \theta_i\right)\right) . \tag{4} \]

然而,实验表明熵损失可以更好地实现分类的随机性,因此我们将熵损失作为默认的不确定性损失。 训练完成后,我们冻结这些领域特定分类器 \(\mathcal{F}_S\) 的参数 \(\theta\),并使用这些分类器来学习领域不变特征。

2.2 Removing domain-specific features

为了去除领域特定分类器学习到的领域特定特征,我们利用一个编码器-解码器网络 \(M\) 将图像映射到一个新的图像空间 \(\mathcal{Z}\)。输出图像被馈送到领域特定分类器 \(\mathcal{F}_S\) 和一个新的领域不变分类器 \(F\) 中。

与领域特定分类器 \(F_i\) 的训练不同,其中源域 \(D_s^i\) 被用于最小化分类损失相反,在这一步中,源域 \(D_s^i\) 被用于最大化不确定性 \(\operatorname{loss} \mathcal{L}_U^M\),

在特定领域分类器\(F_i\)的训练过程中,源域\(D_s^i\)用于最小化分类损失的不同。但是在当前这一步中,源域 \(D_s^i\) 被用于最大化不确定性 \(\operatorname{loss} \mathcal{L}_U^M\),

\[\underset{\theta_M}{\arg \max } \mathbb{E}_{D_s^i \sim \mathcal{D}_s}\left[\mathbb{E}_{\left(\mathbf{x}_j^i, y_j^i\right) \sim D_s^i}\left[L_U\left(F_i\left(M\left(\mathbf{x}_j^i ; \theta_M\right) ; \theta_i\right)\right)\right]\right] . \tag{5} \]

领域特定分类器 \(F_i\) 的参数 \(\theta_i\) 被冻结,而编码器-解码器网络 \(M\) 的参数 \(\theta_M\) 被训练。最大化不确定性损失迫使输出图像 \(\mathbf{z}_i=M\left(\mathbf{x}_i\right)\) 包含比输入图像更少的领域特定特征。通过这样做,编码器-解码器网络可以从输入图像 \(\mathbf{x}\) 中去除领域特定特征,并在输出图像 \(\mathbf{z}\) 中保留领域不变特征。

为了保持输入和输出图像的整体相似性,我们为编码器-解码器网络添加了一个重建损失 \(\mathcal{L}_R^M\),

\[\underset{\theta_M}{\arg \min } \mathbb{E}_{D_s^i \sim \mathcal{D}_s}\left[\mathbb{E}_{\left(\mathbf{x}_j^i, y_j^i\right) \sim D_s^i}\left[L_R\left(M\left(\mathbf{x}_j^i ; \theta_M\right), \mathbf{x}_j^i\right)\right]\right], \tag{6} \]

其中 \(L_R\) 是重建损失函数。我们使用像素级 \(l_2\) 损失作为默认的重建损失,因为它简单并且具有相当良好的性能。也可以使用其他重建损失,比如像素级 \(l_1\) 损失和感知损失[24]。详细讨论请参见补充材料。

接着,我们通过在所有源域的输出图像上最小化分类损失 \(\mathcal{L}_C^{F M}\) 来训练领域不变分类器 \(F\)。

\[\underset{\theta_M, \theta_F}{\arg \min } \mathbb{E}_{D_s^i \sim \mathcal{D}_s}\left[\mathbb{E}_{\left(\mathbf{x}_j^i, y_j^i\right) \sim D_s^i}\left[L_C\left(F\left(M\left(\mathbf{x}_j^i ; \theta_M\right) ; \theta_F\right), y_j^i\right)\right]\right], \tag{7} \]

其中 \(\theta_F\) 是领域不变分类器 \(F\) 的参数。这个分类损失 \(\mathcal{L}_C^{F M}\) 也更新了编码器-解码器网络,以防止由于不确定性损失而使编码器-解码器网络丢失领域不变特征。如果难以区分领域特定特征和领域不变特征,则不确定性损失也有可能去除领域不变特征。

总体而言,在训练领域特定分类器时,我们优化

\[\mathcal{L}_1=\mathcal{L}_C^{F_S}+\lambda_1 \mathcal{L}_U^{F_S} \tag{8} \]

而在学习领域不变特征时,我们优化

\[\mathcal{L}_2=\mathcal{L}_C^{F M}+\lambda_2 \mathcal{L}_U^M+\lambda_3 \mathcal{L}_R^M, \tag{9} \]

其中 \(\lambda_1, \lambda_2\) 和 \(\lambda_3\) 是控制这些损失相对权重的超参数。

为了方便起见,我们将编码器-解码器网络 \(M\) 和领域不变分类器 \(F\) 统称为领域不变模型。在测试阶段,领域不变模型被用于目标域 \(D_t\) 上的分类。

2.3 Explanation of LRDG with respect to existing theory

我们首先介绍领域泛化的generalization risk bound [2],然后进一步解释我们的框架相对于这个界的有效性。

理论上,对于一个领域,相应的任务被定义为一个确定性的真实标签函数 \(f\),其中 \(f: \mathcal{X} \rightarrow \mathcal{Y}\)。这里 \(\mathcal{X}\) 和 \(\mathcal{Y}\) 分别是输入空间和标签空间。我们将候选假设的空间表示为 \(\mathcal{H}\),其中一个假设 \(h: \mathcal{X} \rightarrow \mathcal{Y}\)。假设 \(h\) 在领域 \(\mathcal{D}\) 上的风险被定义为

\[\mathcal{R}[h]=\mathbb{E}_{x \sim \mathcal{D}}[\mathcal{L}(h(x)-f(x))], \tag{10} \]

这里 \(\mathcal{L}: \mathcal{Y} \times \mathcal{Y} \rightarrow \mathcal{R}_{+}\) 用来度量假设和真实标签函数之间的差异。

根据 [2],对于源域 \(\left\{\mathcal{D}_s^1, \mathcal{D}_s^2, \ldots, \mathcal{D}_s^N\right\}\),我们将源域的凸包 \(\Lambda_S\) 定义为一组混合源分布:\(\Lambda_S=\left\{\overline{\mathcal{D}}: \overline{\mathcal{D}}(\cdot)=\sum_{i=1}^N \pi_i \mathcal{D}_s^i(\cdot), 0 \leq \pi_i \leq 1, \sum_{i=1}^N \pi_i=1\right\}\)。我们还将 \(\overline{\mathcal{D}}_t \in \Lambda_S\) 定义为最接近目标域 \(\mathcal{D}_t\) 的域。\(\overline{\mathcal{D}}_t\) 给出为 \(\arg \min _{\pi_1, \ldots, \pi_N} d_{\mathcal{H}}\left[\mathcal{D}_t, \sum_{i=1}^N \pi_i \mathcal{D}_s^i\right]\),其中 \(d_{\mathcal{H}}[\cdot, \cdot]\) 是 \(\mathcal{H}\)-散度 [25],用于量化两个域的分布差异。我们针对目标域 \(\mathcal{D}_t\)使用以下generalization risk bound [2]。

Theorem 1 (Generalization risk bound [2])

在前述设置下,对于任何域 \(\mathcal{D}_t\) 和任意假设 \(h \in \mathcal{H}\),风险 \(\mathcal{R}_t[h]\) 满足以下不等式:

\[\mathcal{R}_t[h] \leq \sum_{i=1}^N \pi_i \mathcal{R}_s^i[h]+\frac{\gamma+\epsilon}{2}+\lambda_\pi, \]

其中 \(\gamma=d_{\mathcal{H}}\left[\mathcal{D}_t, \overline{\mathcal{D}}_t\right]\),\(\epsilon=\sup _{i, j \in[N]} d_{\mathcal{H}}\left[\mathcal{D}_s^i, \mathcal{D}_s^j\right]\),\(\lambda_\pi\) 是一些假设 \(h \in \mathcal{H}\) 在 \(\mathcal{D}_t\) 和 \(\overline{\mathcal{D}}_t\) 上达到的风险之和的最小值。\(\gamma\) 衡量了源域和目标域之间的分布差异。\(\epsilon\) 是源域之间的最大两两 \(\mathcal{H}\)-散度。

定理 1 表明目标域的上界取决于 \(\gamma\) 和 \(\epsilon\)。我们展示了我们的框架可以降低给定领域泛化任务的这个generalization risk bound的值。回想一下,我们的编码器-解码器网络将输入图像映射到一个新的图像空间。我们将映射后的源域表示为 \(\left\{\widehat{\mathcal{D}}_s^1, \widehat{\mathcal{D}}_s^2, \ldots, \widehat{\mathcal{D}}_s^N\right\}\),将映射后的目标域表示为 \(\widehat{\mathcal{D}}_t\)。通过使用领域特定分类器,许多源域的领域特定特征被移除,并且映射后的源域的特征倾向于更加领域不变。因此,映射后的源域 \(\left\{\widehat{\mathcal{D}}_s^1, \widehat{\mathcal{D}}_s^2, \ldots, \widehat{\mathcal{D}}_s^N\right\}\) 的分布差异可能会比原始源域小,即 \(d_{\mathcal{H}}\left[\widehat{\mathcal{D}}_s^i, \widehat{\mathcal{D}}_s^j\right] \leq d_{\mathcal{H}}\left[\mathcal{D}_s^i, \mathcal{D}_s^j\right]\),这表明方程式 11 中的 \(\epsilon\) 可能会减小。在为每个源域移除领域特定特征后,映射后的目标域 \(\widehat{\mathcal{D}}_t\) 会更接近于映射后的源域,因此我们的框架也可能会减小方程式 11 中的 \(\gamma\)。关于定理 1,这些变化提供了对所提出的框架有效性的原理性解释和保证。我们将在实验部分(第 3.3 节)展示这些变化。

标签:Domain,right,removing,domain,领域,分类器,源域,mathcal,left
From: https://www.cnblogs.com/EIPsilly/p/18113402

相关文章

  • 中断子系统(一)IRQ Domain
    前言在现代计算机系统中,中断模块的硬件越来越复杂,有可能有多个中断控制器(InterruptController,IC)之间进行级联从而拓展可以管理的中断源数量。这就会产生几个问题,每个IC上都连接着多个设备,IC会给irqline连接的每一个设备分配一个硬件中断请求号(HWinterruptnumber,hwirq),不同......
  • 论文解读(UDA-GCN)《Unsupervised Domain Adaptive Graph Convolutional Networks》
    Note:[wechat:Y466551|可加勿骚扰,付费咨询]论文信息论文标题:UnsupervisedDomainAdaptiveGraphConvolutionalNetworks论文作者:论文来源:2020aRxiv论文地址:download 论文代码:download视屏讲解:click1-摘要图卷积网络(GCNs)在许多与图相关的分析任务中都取得了令人印......
  • A LARGE LANGUAGE MODEL EVALUATION BENCHMARK AND BASELINE FOR CHINESE PUBLIC SECU
    本文是LLM系列文章,针对《CPSDBENCH:ALARGELANGUAGEMODELEVALUATIONBENCHMARKANDBASELINEFORCHINESEPUBLICSECURITYDOMAIN》的翻译。CPSDBENCH:中国公共安全领域的大型语言模型评估基准和基线摘要1引言2相关工作3方法4结果与分析5结论摘要大......
  • 关于衍射光波导设计中的K阈(k-domain)分析的一些学习
     对于衍射光波导的设计来说,不能简单利用几何光的方法对光线的传播路径进行描述。因此可以基于K空间波矢的矢量运算来进行描述。 在阈值分析中,衍射光波导的光线传播遵循二个引导条件,分别为全内反射条件和引导模式条件。如图所示。             ......
  • Windows 系统中进行一些域管理操作 net group /domain 命令 参数
    Windows系统中进行一些域管理操作。以下是一些常见的用法和参数:列出所有域用户组:bashCopyCodenetgroup/domain查看特定用户组的成员:bashCopyCodenetgroup"GroupName"/domain添加用户到指定用户组:bashCopyCodenetgroup"GroupName"UserName/add/domain......
  • Codeforces Round 656 (Div. 3) F. Removing Leaves
    ProblemDescription给出一棵\(n\)个节点的无根树。你可以进行以下操作:选择\(k\)个共同父节点的叶子节点,将\(k\)个节点和与父节点相连的边删去。求最大操作次数。Input第一行输入一个整数\(t\)\((1\let\le2\times10^4)\),表示测试组数。接下来每组测试数据第......
  • CF351D - Jeff and Removing Periods 题解
    首先做一点显然的转化:在进行第一次操作之后,可以将相同的数排在一起,这样一次就能删掉一种数。如果一开始就能删光一种数的话,那么次数就是区间颜色数,否则就是区间颜色数\(+1\)。所以现在原问题变成了两个问题:求区间内不同颜色数,判断区间内是否有某种颜色满足其出现位置构成等差数......
  • SciTech-Mathmatics-FourierSeries: Time Domain and Frequency Domain
    TimeDomainandFrequencyDomainFrequencydomain:measuredbySpectrumAnalysiszerTellsushowproperties(amplitudes)changeoverfrequencies:TimeDomain:measuredbyOscilloscopeTellsushowproperties(suchasAmplitude(Power),Phase,andsoon)......
  • 领域驱动设计(Domain-Driven Design,简称DDD)【简介 个人学习笔记】
    找到了第1篇资料:领域驱动设计详解:是什么、为什么、怎么做?-知乎找到了第2篇资料:领域驱动架构(DDD)建模中的模型到底是什么?-知乎找到了第3篇资料:一文看懂DDD领域驱动设计-知乎找到了第4篇资料:什么是DDD(领域驱动设计)?这是我见过最容易理解的...找到了第5篇资料:领......
  • CF1385F Removing Leaves 题解
    解题思路简单贪心,优先选择叶子节点最多的,这样能够保证一定能把所有能删的都删了。因为要建一个可删除的图,所以我们可以使用set来存边,不然就需要维护一堆东西……那么我们肯定是从有叶子节点的点向父亲更新的,那么我们每次选择叶子节点最多的点,然后删除\(k\)个叶子,判断一下删......