论文背景:
本文由何凯明先生主笔,在原先的对比学习模型---如SIMCLR中,需要大量的负样本以供系统学习其特征分布,但是在多数场景下,样本空间中往往负样本不足或过大,比如在一个百万用户量级的推荐系统中,若将与目标用户不像关系的用户视为负样本,则时间计算度会非常大,基本无法完成运算,亦或当在不平衡数据集下,如异常数据集中,系统也无法学习到良好的特征表示。
论文的目标与成果:
本文提出了MOCO(Momentum Contrast)动量学习模型,通过维护一个队列存储负样本,可以有效提升计算效率,增大样本的使用效率。具有如下三个优势
- 小批量高效训练:MoCo相比于其他对比学习方法如SimCLR,不依赖大批量训练,因此在计算资源有限的情况下仍能取得优秀的表现。
- 负样本队列:通过动量队列,MoCo可以保留和使用大量的负样本,这有助于提高模型的学习能力和泛化能力。
- 动量编码器:动量编码器使得负样本的表示更加稳定,从而提高了模型的对比学习效果。
模型架构与核心思想
在已知对比学习的逻辑---将每张图片增强生成两张图片A,图片B--的基础之上,我们修改图片生成的逻辑,生成图片A使用查询编码器(Query Encoder),生成图片B使用键编码器(Key Encoder)
- 查询编码器(Query Encoder):把一张增强后的猫的图片输入查询编码器,得到它的特征表示向量。
- 键编码器(Key Encoder):把另一张增强后的猫的图片输入键编码器,得到它的特征表示向量。
对学习过程中,对每个图片都做此处理,将处理后的键向量存入队列,生成的查询向量将队列中的样本视为负样本进行对比学习,
一个完整的例子:
假设你有以下动物图片:
- 一张猫的图片(原图)
- 一张狗的图片(原图)
- 一张鸟的图片(原图)
在MoCo中,流程如下:
-
增强操作:
- 对猫的图片应用两次不同的增强操作,比如一次裁剪,另一次颜色调整,生成两张增强后的猫的图片。
- 同时对狗和鸟的图片也进行类似的增强操作。
-
编码过程:
- 将增强后的猫图片1通过查询编码器,得到一个特征向量A。
- 将增强后的猫图片2通过键编码器,得到另一个特征向量B。
- 将增强后的狗和鸟的图片分别通过键编码器,得到它们的特征向量C和D。
-
对比学习:
- MoCo希望最大化猫图片1(A)和猫图片2(B)的相似性,同时最小化猫的特征A与狗(C)和鸟(D)特征的相似性。
- 模型会根据这个目标调整查询编码器的参数,学习到猫的特征表示应当与自己的增强版本保持一致,但与其他动物保持不同。
-
负样本队列:
狗的特征C和鸟的特征D会被存储在一个负样本队列中,供以后处理其他图像时使用。随着训练的进行,负样本队列会越来越丰富,帮助模型更好地区分不同类别的图像
通过此处理,我们可以将负样本稳定化存储在队列中,而不必每一次都重新计算生成,大大减少计算量,同时也让负样本队列保持相对稳定,提升系统的稳定性,便于系统学习特征。
MoCo的网络架构:
-
查询编码器(Query Encoder):
- 定义:查询编码器用于将查询图像(即输入图像)转换为特征表示向量。该编码器通常是一个标准的卷积神经网络(CNN),比如常用的ResNet-50。
- 作用:查询编码器的作用是将增强后的图像输入进行编码,生成用于对比学习的特征表示。它的输出是一个高维的特征向量。
- 更新机制:查询编码器的参数通过标准的反向传播算法进行更新(梯度下降),以最小化损失函数。
-
键编码器(Key Encoder):
- 定义:键编码器与查询编码器的结构相同,通常也是ResNet-50,但它的更新方式不同。键编码器的参数更新通过动量更新机制完成。
- 作用:键编码器用于编码正样本(即查询图像的另一个增强版本),并生成其特征表示。与查询编码器不同,键编码器的参数通过动量机制慢慢更新,以保持其输出的稳定性。
- 动量更新机制:键编码器的参数不是直接通过反向传播更新,而是通过如下公式进行动量更新: θk=m⋅θk+(1−m)⋅θq\theta_k = m \cdot \theta_k + (1 - m) \cdot \theta_qθk=m⋅θk+(1−m)⋅θq 其中,θk\theta_kθk 是键编码器的参数,θq\theta_qθq 是查询编码器的参数,mmm 是动量系数(通常设置为接近1,例如0.999)。动量更新使得键编码器的参数更新较慢,从而生成更加稳定的特征表示。
-
负样本队列(Negative Sample Queue):
- 定义:MoCo引入了一个固定大小的队列,用于存储负样本的特征表示。每次训练时,队列中存储的负样本特征用于对比学习,随着新的负样本加入,旧的负样本会被移除。
- 作用:负样本队列的作用是解决负样本不足的问题。通过维持一个包含之前批次的负样本特征的队列,模型不需要依赖大批量训练就可以获得足够的负样本进行学习。
- 队列更新:每当键编码器生成新的特征表示时,这些表示会被添加到队列中,并将队列中的最旧的特征移除。队列始终保持固定的大小。
详细工作流程:
-
图像增强:对于每个输入图像,MoCo会生成两个不同的增强版本。一个增强版本作为查询图像,另一个增强版本作为正样本。
-
查询编码器处理:查询图像通过查询编码器生成一个特征向量 qqq。
-
键编码器处理:正样本图像通过键编码器生成另一个特征向量 k+k^+k+。
-
对比学习:MoCo通过InfoNCE损失函数来最小化查询图像特征 qqq 和正样本特征 k+k^+k+ 之间的距离,同时最大化 qqq 与队列中负样本特征之间的距离。
-
负样本队列:负样本队列保存了之前批次的负样本特征。每次新生成的负样本会加入队列,旧的负样本会被替换。
动量更新机制的优点:
- 特征表示的稳定性:键编码器的动量更新机制使得其参数更新较慢,能够生成更稳定的负样本特征,避免了模型训练中特征表示的剧烈波动。
- 减少负样本依赖:通过维护负样本队列,MoCo可以避免大批量训练的需求,因为队列提供了足够多的负样本用于对比学习。
总结:
MoCo的网络架构由查询编码器、键编码器和负样本队列组成,辅以动量更新机制,确保负样本特征表示的稳定性和丰富性,从而提升对比学习的效果。通过这种设计,MoCo在保持计算效率的同时,解决了小批量训练中的负样本不足问题。