首页 > 其他分享 >[论文速览] Hard Patches Mining for Masked Image Modeling

[论文速览] Hard Patches Mining for Masked Image Modeling

时间:2023-07-12 11:01:04浏览次数:35  
标签:Mining mathbf Patches Image mask 损失 patch mathcal mathrm

Pre

title: Hard Patches Mining for Masked Image Modeling
accepted: CVPR 2023
paper: https://arxiv.org/abs/2304.05919
code: https://github.com/Haochen-Wang409/HPM
ref: CVPR 2023 | 挖掘困难样本的 MIM 框架: Hard Patches Mining for Masked Image Modeling

关键词:MIM, self-supervised, 自监督掩码学习
阅读理由:CVPR, 方法似乎简单有效

Idea

先让模型根据重建损失的大小自己生成困难mask,再像传统方法那样训练模型去预测masked patches

Motivation&Solution

m: NLP的词已经高度语义化,而CV里图片则存在空间信息的冗余,各种自监督掩码学习方法的性能强烈依赖于手工定义的掩码策略

s: 提出一种新的困难样本挖掘策略,让模型自主地掩码困难样本

Background

对比学习能学到视角不变的特征。
SimMIM发现大mask核在不同mask比例下更加鲁棒

Method(Model)

Overview

HPM 包含一个学生模型和一个教师模型,它们共享网络结构,包含 encoder \(f_{\theta}\),图像重建 decoder \(d_{\phi}\),损失预测 decoder \(d_{\psi}\)。教师模型的参数是由学生模型指数平滑更新而来的。

每次迭代,一张图片首先打成 patch,并经过教师模型,得到每个 patch 预测的重建损失。进而,基于该预测,产生当前的二元 mask \(\mathbf{M}\) 用于 MIM 任务,0表示遮掩。损失函数包含两项:

\[\mathcal{L} = \mathcal{L}_{\mathrm{rec}} + \mathcal{L}_{\mathrm{pred}}, \tag{2} \]

其中 \(\mathcal{L}_{\mathrm{rec}}\) 表示重建损失,是标准的 MIM 损失;而 \(\mathcal{L}_{\mathrm{pred}}\) 表示的是重建损失预测损失。

\[\mathcal{L}_{\mathrm{rec}} = \mathcal{M} \left( d_{\phi_s}(f_{\theta_s}(\mathbf{x} \odot \mathbf{M})), \mathcal{T}(\mathbf{x} \odot (1 - \mathbf{M})) \right), \tag{3} \]

其中 \(\mathbf{M} \in \{0,1\}^N\) 表示产生的二元 mask, \(\odot\) 表示 element-wise dot product,因此 \(\mathbf{x} \odot \mathbf{M}\) 表示可见的 patches。 \(\mathcal{T}(\cdot)\) 表示 target 的 transformation,例如 MAE 中就是一个恒等映射,而 BEiT 中则是将图像转化为离散的 token。 \(\mathcal{M}(\cdot, \cdot)\) 表示某种度量,如 MAE 中用的 $\ell_2 $ 距离,SimMIM 中用的 smooth \(\ell_1\) 距离。

重建损失L_pred

Absolute loss.
一种最直观的方法就是直接最小化真实重建 loss \(\mathcal{L}_{\mathrm{rec}}\) 和预测的重建 loss 之间的 MSE,即

\[\mathcal{L}_{\mathrm{pred}} = \left( d_{\psi_s}(f_{\theta_s} (\mathbf{x} \odot \mathbf{M})) - \mathcal{L}_{\mathrm{rec}} \right)^2 \odot (1 - \mathbf{M}), \tag{4} \]

其中 \(d_{\psi_s}\) 表示的是学生模型的 loss predictor,而 $\mathcal{L}_{\mathrm{rec}} $ 需截断梯度。

然而,这里的目标是确定图像中的困难样本,需要 patch 之间重建损失的相对大小,因此 MSE 并不是最合适的选择,因为 \(\mathcal{L}_{\mathrm{rec}}\) 将随着训练的进行而减少, \(\mathcal{L}_{\mathrm{pred}}\) 也会变小,但这不代表它学到了东西,为此作者提出了一种基于二元交叉熵的相对损失。

Relative loss.
给定一张含有N个 patch 的图片,其真实的重建损失为 \(\mathcal{L}_{\mathrm{rec}} \in \mathbb{R}^N\) ,目的是预测这N个 patch 之间重建损失的相对大小,即 \(\texttt{argsort}(\mathcal{L}_{\mathrm{rec}})\) ,但 \(\texttt{argsort}(\cdot)\) 不可导,因此作者将其转换为dense relation comparison问题,预测patch两两之间的大小关系:

\[\begin{aligned} \mathcal{L}_{\mathrm{pred}} = &-\sum_{i=1}^N \sum_{j=1 \atop j\neq i}^N \mathbb{I}^{+}_{ij} \log \left( \sigma(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j) \right) \\ &-\sum_{i=1}^N \sum_{j=1 \atop j\neq i}^N \mathbb{I}^{-}_{ij} \log \left( 1 - \sigma(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j) \right), \end{aligned} \tag{5} \]

其中 \(\hat{\mathcal{L}}^s = d_{\psi_s}(f_{\theta_s}(\mathbf{x} \odot \mathbf{M})) \in \mathbb{R}^N\) 是学生模型输出的损失预测值,而 \(i, j=1,2,\dots,N\) 是 patch indexes。 \(\sigma(z) = e^z / (e^z + 1)\) 是 \(\texttt{sigmoid}\) 函数。 \(\mathbb{I}^{+}_{ij}\) 和 \(\mathbb{I}^{-}_{ij}\) 是两个指示函数,表示 patch i 和 patch j 的真实重建损失大小,定义如下:

\[\mathbb{I}^{+}_{ij} = \left\{ \begin{aligned} &1, &&\mathcal{L}_{\mathrm{rec}}(i) > \mathcal{L}_{\mathrm{rec}}(j) \mathrm{\ and\ } \mathbf{M}_i=\mathbf{M}_j=0, \\ &0, &&\mathrm{otherwise}, \end{aligned} \right. \\ \mathbb{I}^{-}_{ij} = \left\{ \begin{aligned} &1, &&\mathcal{L}_{\mathrm{rec}}(i) < \mathcal{L}_{\mathrm{rec}}(j) \mathrm{\ and\ } \mathbf{M}_i=\mathbf{M}_j=0, \\ &0, &&\mathrm{otherwise}, \end{aligned} \right. \]

其中 \(\mathbf{M}_i=\mathbf{M}_j=0\) 表示对应的patch i, j应当被mask。
对于 \(\mathbb{I}^{+}_{ij}\) ,+表示此时i的损失应当大于j,值为1表示损失需要计算,此时 \(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j\) 以0为界,越大表示预测出来i的损失比j越大,符合target的 \(\mathbf{M}_i=\mathbf{M}_j=0\) ,则损失越小。

以前半部分损失为例,本质上损失通过\(\hat{\mathcal{L}}^s\)之间的关系定义,当 \(\mathbb{I}^{+}_{ij}\) 有值1表示i的真实损失应当大于j,此时如果 \(\hat{\mathcal{L}}^s_i > \hat{\mathcal{L}}^s_j\) 且差值越大,则经过sigmoid函数后值越是比0.5大且接近1,再过log最终会是一个负值,配合最前面的负号,会是一个正值且向0靠近。

Easy-to-Hard Mask Generation

一个自然的想法就是每次迭代过程中,先基于教师模型计算 \(\texttt{argsort}[d_{\psi_t}(f_{\theta_t} (\mathbf{x}))]\) ,然后 mask 掉 top-75% 的 patch。然而,在早期训练阶段,学到的大多是纹理,重建损失与判别性(能决定图像类别的前景主体?)还没有建立起相应的关系。为此作者提出了一种由易到难的掩码生成方式,提供一些合理的提示,引导模型一步一步地重建掩码的硬块。

具体来说,假设 mask ratio 为 \(\gamma\) ,则在 t 次迭代只 mask 掉最大的 \(\alpha_t\gamma N\) 个 patch,剩余 \((1-\alpha_t)\gamma N\) 个需要 mask 的 patch 则随机产生,其中 \(\alpha_t = \alpha_0 + \frac{t}{T}(\alpha_T - \alpha_0).\) 。也就是随着训练的推进,逐渐降低随机mask的比例。

算法1 Pytorch风格的HPM伪代码

Experiment

Training Detail

以 ViT-B/16 为 backbone,预训练 200 epochs

Dataset

ImageNet-1K

Results

Ablation Study

表1 在不同重建目标上的消融研究。第一行是MAE baseline,以自回归地生成RGB像素的形式训练,后面三个以教师模型的特征为target(知识蒸馏),架构都一样(ViT-B/16)。

表5 下游任务上的消融 从表1取了两个预训练模型做下游任务

重建目标的消融。可以看到,不管以什么为重建目标,加入作为\(\mathcal{L}_{\mathrm{pred}}\)额外的损失,并基于此进一步产生更难的mask都能获得性能提升。仅仅引入 \(\mathcal{L}_{\mathrm{pred}}\) 也能够带来性能提升,表明挖掘困难样本的能力本身就能够促使学到更好的特征表示这一点不仅在分类任务上得到体现,下游任务(检测分割)也有相应的体现。

表2 不同mask策略的消融,较大的$\alpha_T$表明代理任务(pretext task)更加困难,但该策略的随机性就会下降。

Esay-to-hard: \(\alpha_t = \alpha_0 + \frac{t}{T}(\alpha_T - \alpha_0).\)

难度大的代理任务确实能够带来性能提升,但保留一定的随机性也是同样必要的。直接掩盖那些预测损失最高的 patch 虽然带来了最难的问题,但图像可判别部分几乎被被掩盖了,意味着 可见的patch 几乎都是背景(见图2)。在没有任何提示的情况下,强迫模型只根据这些背景来重建前景是没有意义的。

表3 不同mask策略的消融。验证在预测的重建损失上使用 argmax(·) 的有效性,argmin(·) 表示每次mask简单的patch。 $\alpha_0 > \alpha_T$ 表示困难patch的使用逐渐增加,是hard-to-easy的方式

进一步地,探究困难的代理任务对于 MIM 是否有帮助。其中, argmin 表示这个任务甚至简单于 random masking,跟 hard-to-easy 一样都会导致性能退化。

表4 预测损失形式上的消融,对比了不加和加上公式4、5的两种损失

MSE 相较于 baseline 能够有提升,但 BCE 是一个更好的选择。

表6 在ImageNet-1K上对比SOTA,一个横线的表示作者实现的版本,两个横线的表示从其他论文抄过来的数据 eff. ep. 表示 Effective pre-training epoch

表6将HPM跟其他方法对比,用于对比的分三类:对比学习、像素回归的MIM、特征蒸馏的MIM

表7 在ADE20k上对比SOTA,两个横线的表示从其他论文抄过来的数据

图4 COCO验证集的可视化,该数据集训练时没见过,右边是预测的重建损失

Conclusion

HPM作为一种即插即用的模块可以无缝接入现有的框架中,性能都能得到提升。MIM的常见问题是线性探测和k-NN分类性能不如对比学习方法。此外HPM由于有个额外的解码器,会有更大的计算开销,比起MAE baseline,在训练ViT-L时会花费1.1倍的时间。将来的方向可以是设计一种更好的损失预测任务,不借助额外辅助解码器。

Critique

思想其实挺简单的,但做了非常多的实验来证明,附录还有一大堆。感觉这类还是看了大概就好,除非要去用它的代码。

Unknown

标签:Mining,mathbf,Patches,Image,mask,损失,patch,mathcal,mathrm
From: https://www.cnblogs.com/Stareven233/p/17469947.html

相关文章

  • el-image 插槽样式补全
    问题el-image有两个插槽:placeholder和error。依照demo使用时,样式会发生偏差:demo:使用:解决办法F12打开demo元素页,发现有如下样式:因此,把类名粘贴到全局样式中即可://el-image插槽样式.image-slot{font-size:14px;display:flex;justify-content:c......
  • Django 使用 ImageKit 进行的ImageField 图像处理
     有图像的话,肯定不知保存,需要改变图像的像素,大小等,这就需要第三方的libpipinstall-Udjango-imagekitpipinstall-UPillow settings.py里面,追加imagekit  ,MEDIA_URL,MEDIA_ROOT model.pyfromdjango.dbimportmodelsfromimagekit.modelsimportImageSp......
  • Unity UGUI的Image(图片)组件的介绍及使用
    UGUI的Image(图片)组件的介绍及使用1.什么是UGUI的Image(图片)组件?UGUI的Image(图片)组件是Unity引擎中的一种UI组件,用于显示2D图像。它提供了一种简单而灵活的方式来在游戏中加载和显示图片。2.为什么要使用UGUI的Image(图片)组件?使用UGUI的Image组件可以方便地在游戏中展示各种图片......
  • java BufferedImage怎么转byte[]?
    一.为什么要将BufferedImage转为byte数组?在传输中,图片是不能直接传的,因此需要把图片变为字节数组,然后传输比较方便。而字节数组变成BufferedImage能够还原图像。参考1:https://blog.csdn.net/weixin_39958559/article/details/114788932参考2:https://blog.csdn.net/itigoitie/......
  • Visual Studio2019 BackgoroundImageLayout属性
    ​BackgroundImageLayout属性值背景图片重复:BackgroundImageLayout属性设置为Tile(默认)背景图片左边显示:BackgroundImageLayout属性设置为None背景图片右边显示:BackgroundImageLayout属性设置为None,同时RightToLeft属性设置为Yes背景图片居中显示:BackgroundImageLayout属性设......
  • OGG-02912 Patch 17030189 is required on your Oracle mining database for trail fo
    Therewillbeascript"prvtlmpg.plb"undergghomedirectory[oracle@OGGR2-1ogg]$ls-lrtprvtlmpg.plb-rw-r-----1oracleoinstall9487May272015prvtlmpg.plb[oracle@OGGR2-1ogg]$pwd/ogg[oracle@OGGR2-1ogg]$Logintothedatabaseand......
  • draw line on image
    cv2.line(image,start_point,end_point,color,thickness)#Pythonprogramtoexplaincv2.line()method#importingcv2importcv2image=cv2.imread(path)start_point=(0,0)end_point=(250,250)color=(0,255,0)thickness=9#Usingcv2.line(......
  • flatpak appimage大小对比
    格式:单应用大小/加上依赖/安装后大小 flatpakAppImageaptgimp127.7MB/797.6MB/366MB164MB vscodium120MB/972MB/335MB82.2MB82.2MBblender383MB/410MB/1187.84MB209.64MB244MB对比了三款常用的桌面软件,看得出f......
  • IOS开发-使用UIImageView加载网络图片
    使用UIImageView加载网络图片可以分为三步1.创建UIImageView实例:UIImageView*imgview=[[UIImageViewalloc]init];imgview.frame=CGRectMake((self.view.frame.size.width-100)/2,(self.view.frame.size.height-100)/2,100,100); 2.下载图片数据:NSUR......
  • IOS开发-UIImageView基本用法
    UIImageView是iOS中用于显示图像(图片、gif、svg等)的视图。它的主要功能有:1.显示图片UIImageView可以通过image属性显示一张UIImage类型的图片。可以是本地图片、从网络下载的图片等。2.设置填充模式可以通过contentMode属性设置图片在UIImageView内的显示和填充模式。内容......