总览
一般的机器学任务是,给定一个输入,预测其对应的的标签、值或一组值。这样的任务使用像是交叉熵损失 Cross-Entropy Loss 和均方误差损失 Mean Square Error Loss 就行。度量学习 Metric Learning 则不一样,它的目标是预测不同输入的相对距离。例如,衡量两张人脸的相似程度,或是推理两句话表达含义的相似度。
衡量相似度的通常做法是,让模型使用各种输入 生成代表特征的 embedding 向量,然后用各个向量间的距离(例如欧氏距离、余弦相似度)衡量这些输入的相似度。
排名损失 Ranking Loss
排名损失 Ranking Loss 函数就是用于解决 Metric Learning 问题的损失函数。根据应用场景的不同,Ranking Loss 拥有不同的名称,例如对比损失 Contrastive Loss、间隔损失 Margin Loss、铰链损失 Hinge Loss 或三元损失 Triplet Loss。
常用损失函数主要有对比损失 Contrastive Loss 和三元组损失 Triplet Loss。介绍之前,先给一些符号定义:
- \(f(I)\),使用模型将图片转换为 embedding
- \(\left\Vert f(I_p)-f(I_q) \right\Vert_2\),取两个 embedding 的欧氏距离
对比损失 Contrastive Loss
对比损失 Contrastive Loss,也可称为 成对损失 Pairwise Loss。
\[L=y \left\Vert f(I_p)-f(I_q) \right\Vert_2 +(1-y)\max \left\{ 0, m-\left\Vert f(I_p)-f(I_q) \right\Vert_2 \right\} \]- \(I_p\) \(I_q\),两张图片
- \(y\),当 \(I_p\) \(I_q\) 被标记为相似时为 1,被标记为不相似时为 0
- \(m\),不相似阈值
可以这样解释:若 \(I_p\) \(I_q\) 相似,直接用欧氏距离作为损失;若 \(I_p\) \(I_q\) 不相似,欧式距离大于 \(m\) 时进行惩罚,小于 \(m\) 时不做处理。
三元组损失 Triplet Loss
\[L=\max \left\{ 0, \left\Vert f(I_a)-f(I_q) \right\Vert_2 - \left\Vert f(I_a)-f(I_n) \right\Vert_2 + m \right\} \]- \(I_a\),锚点图片
- \(I_p\),正样本,与 \(I_a\) 相似的图片
- \(I_n\),负样本,与 \(I_a\) 不相似的图片
- \(m\),期望的正负样本差距
这个损失函数很直观,最小化正样本欧氏距离、最大化负样本欧氏距离。\(m\) 决定了正负样本期望的与锚点的欧氏距离之差,值越大则模型对正负样本分得越开。
Triplet Loss 将学习任务从 Contrastive Loss 的 “使得正负样本差距更大” 变为了 “使得正负样本相对于锚点的差距更大”。实验结果证明这样做拥有更好的性能。
mini-batch
使用排名损失会带来的问题:
- 样本爆炸。我们不可能训练所有可能的正负样本组合
- 简单样本陷阱。随着网络的训练,越来越多的正负样本组合会得到极小的 loss,应当有选择地为网络提供困难的样本组合
- 计算浪费。使用三元组损失时,一次需要计算 3 个 embedding,但这些 embedding 只会用来算一次损失
为了解决这些问题,使用 mini-batch 是一个很好的思路。例如一个 mini-batch 里有 64 个样本,不仅算出的 64 个 embedding 可以排列组合利用多次,还能在选定锚点样本和正样本后,选择一个对模型来说相对分辨困难的负样本。具体的样本选择策略很值得探讨,但本文不再深究。可以查阅 pytorch-metric-learning
文档中的 Miners 一节(https://kevinmusgrave.github.io/pytorch-metric-learning/miners/)。
参考来源
- EORA.ai,“How to Build an Image Search Engine to Find Similar Images”,https://medium.com/@eora/how-to-build-an-image-search-engine-to-find-similar-images-c6ec429b9a27
- “Understanding Ranking Loss, Contrastive Loss, Margin Loss, Triplet Loss, Hinge Loss and all those confusing names”,https://gombru.github.io/2019/04/03/ranking_loss/
- 一个用 MNIST 数据集进行可视化对比实验的项目。能从各个图中直观看到网络结构、排名损失和负样本策略带来的影响。https://github.com/adambielski/siamese-triplet