首页 > 其他分享 >对比学习简记

对比学习简记

时间:2023-02-21 20:24:26浏览次数:50  
标签:loss mathbf similarities 样本 学习 简记 device theta 对比

目录

Self-Supervised Learning 的核心思想

Unsupervised Pre-train, Supervised Fine-tune.

img

两大主流方法

  • 基于 Generative 的方法
  • 基于 Contrative 的方法

基于 Generative 的方法主要关注的重建误差,还原原始输入;

基于Contrastive 的方法不要求模型能够重建原始输入,而是希望模型能够在特征空间上对不同的输入进行分辨,判断输入是否相似。

img

实践应用

  • BERT系列:nlp
  • VIT系列:cv
  • data2vec系列:multimodal
  • SimCLR系列:对比学习
  • MoCo系列

Contrastive Representation Learning

对比学习指导原则

  • 构造相似实例和不相似实例
  • 习得一个表示学习模型,使得相似的实例在投影空间中比较接近,而不相似的实例在投影空间中距离比较远

对比学习目标函数

Contrastive Loss

最早的Loss是对比Loss,即同类样本间更相似,最小化同类样本的embedding 距离,最大化非同类embedding的距离

\[\mathcal{L}_\text{cont}(\mathbf{x}_i, \mathbf{x}_j, \theta) = \mathbb{1}[y_i=y_j] \| f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \|^2_2 + \mathbb{1}[y_i\neq y_j]\max(0, \epsilon - \|f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j)\|_2)^2 \]

Triplet Loss

Triplet loss最小化anchor和正样本间的距离,最大化anchor和负样本间的距离

\[\mathcal{L}_\text{triplet}(\mathbf{x}, \mathbf{x}^+, \mathbf{x}^-) = \sum_{\mathbf{x} \in \mathcal{X}} \max\big( 0, \|f(\mathbf{x}) - f(\mathbf{x}^+)\|^2_2 - \|f(\mathbf{x}) - f(\mathbf{x}^-)\|^2_2 + \epsilon \big) \]

关键:选择合适的负样本,提升模型性能

img

N-pair Loss

Multi-Class N-pair loss generalizes triplet loss to include comparison with multiple negative samples.

Given a \((N + 1)\) tuplet of training samples,\(\{ \mathbf{x}, \mathbf{x}^+, \mathbf{x}^-_1, \dots, \mathbf{x}^-_{N-1} \}\), including one positive and \(N-1\) negative ones, N-pair loss is defined as:

\[\begin{aligned} \mathcal{L}_\text{N-pair}(\mathbf{x}, \mathbf{x}^+, \{\mathbf{x}^-_i\}^{N-1}_{i=1}) &= \log\big(1 + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i) - f(\mathbf{x})^\top f(\mathbf{x}^+))\big) \\ &= -\log\frac{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+))}{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+)) + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i))} \end{aligned} \]

当负样本数量为1时,等价于多分类softmax。

NCE

把多分类问题转化成二分类,判断正样本和负样本是否为同一类。

\[\begin{aligned} \mathcal{L}_\text{NCE} &= - \frac{1}{N} \sum_{i=1}^N \big[ \log \sigma (\ell_\theta(\mathbf{x}_i)) + \log (1 - \sigma (\ell_\theta(\tilde{\mathbf{x}}_i))) \big] \\ \text{ where }\sigma(\ell) &= \frac{1}{1 + \exp(-\ell)} = \frac{p_\theta}{p_\theta + q} \end{aligned} \]

其中 target sample \(\sim P(\mathbf{x} \vert C=1; \theta) = p_\theta(\mathbf{x})\), noise sample \(\sim P(\tilde{\mathbf{x}} \vert C=0) = q(\tilde{\mathbf{x}})\).

InfoNCE

\[\mathcal{L}_q = - \log\dfrac{\exp(q k_+ / \tau)}{\sum_i \exp(q k_i / \tau)} \]

假设我们忽略\(\tau\),那么infoNCE loss其实就是cross entropy loss。唯一的区别是,在cross entropy loss里,\(k\)

指代的是数据集里类别的数量,而在对比学习InfoNCE loss里,这个\(k\)指的是负样本的数量.

如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差.

对比学习损失(InfoNCE loss)与交叉熵损失的联系,以及温度系数的作用 - Youngshell的文章 - 知乎 https://zhuanlan.zhihu.com/p/506544456

对比学习关键点

数据增强

对原始数据增加噪音等数据增强,生成正样本。

如SimCLR表明,随机裁剪和随机颜色失真对视觉表示学习非常关键。

大的BatchSize

对依赖In-batchNegative的场景,大的batch size可以提高训练效率,增加模型挑战。

hard Negative Mining

对于有监督情况,可以直接将其它类样本作为负样本;

对于无监督情况,可能会偶然把同类样本作为负样本,导致性能大幅下降。

img

vision

Image Augmentation

裁剪、缩放、加噪、翻转、转换灰度图

常用框架:AutoAugment、RandAugment、PBA、UDA

图像混合

SimCLR

img

img

loss, 每个Batch里面的所有Pair的损失之和取平均:

\[L = \frac{1}{2N}\sum_{k=1}^{N}[l(2k-1,2k)+l(2k,2k-1)] \]

img

CLIP

img

nlp

Text Augmentation

编辑距离、随机替换、删除等

SimCSE

Unsupervised SimCSEimage-20230221200719051

Supervised SimCSE

image-20230221200755790

img

SimCSE2 中改进了两点:

  1. 负样本质量,原本都是同一句话的embedding dropout,但句子长度相同,会导致模型倾向。
  2. batchsize过大,引起性能下降,未解之谜
def unsup_loss(y_pred, lamda=0.05, device="cpu"):
    idxs = torch.arange(0, y_pred.shape[0], device=device)
    y_true = idxs + 1 - idxs % 2 * 2
    similarities = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=2)

    similarities = similarities - torch.eye(y_pred.shape[0], device=device) * 1e12

    similarities = similarities / lamda

    loss = F.cross_entropy(similarities, y_true)
    return torch.mean(loss)


def sup_loss(y_pred, lamda=0.05, device="cpu"):
    row = torch.arange(0, y_pred.shape[0], 3, device=device)
    col = torch.arange(y_pred.shape[0], device=device)
    col = torch.where(col % 3 != 0)[0]
    y_true = torch.arange(0, len(col), 2, device=device)
    similarities = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=2)

    similarities = torch.index_select(similarities, 0, row)
    similarities = torch.index_select(similarities, 1, col)

    similarities = similarities / lamda

    loss = F.cross_entropy(similarities, y_true)
    return torch.mean(loss)

References

【1】对比学习(Contrastive Learning):研究进展精要. https://zhuanlan.zhihu.com/p/367290573

【2】https://lilianweng.github.io/posts/2021-05-31-contrastive/

【3】Self-Supervised Learning 超详细解读 (目录) - 科技猛兽的文章 - 知乎 https://zhuanlan.zhihu.com/p/381354026

【4】对比学习损失(InfoNCE loss)与交叉熵损失的联系,以及温度系数的作用 - Youngshell的文章 - 知乎 https://zhuanlan.zhihu.com/p/506544456

标签:loss,mathbf,similarities,样本,学习,简记,device,theta,对比
From: https://www.cnblogs.com/gongyanzh/p/17142270.html

相关文章

  • docker学习
    1.背景打算装虚拟机,嫌麻烦,想到docker也可以实现,所以在本地部署docker2.docker和虚拟机的区别linux环境安装docker和制作镜像win11环境安装docker和制作镜像......
  • 学习进度
    今天完成了对android入门的学习,用了2小时的时间。我知道了什么是安卓开发,并安装了Androidstudio编译器,知道了什么是线性布局,需要把原有的布局删掉,重新写一个布局,并且了解了......
  • 2023.2.21 我的第一篇博客——软件工程学习心得体会
    今天是我第一次在博客园写博客,本人目前是上海海洋大学软件工程系大二在读,第一篇博客就聊聊我这一年半对软件工程学习的感想吧。编程语言方面,大一学习了C和C++,大二上学期学......
  • 2月21日javaweb学习之MyBatis
    MyBatis是一款优秀的持久层框架,所谓持久层就是负责将数据保存到数据库的那一层代码。(1)MyBatis快速入门,查询user表中所有的数据1.创建user表,添加数据2.创建模块,导入坐标......
  • 学习记录(2.21)
    今天总共学习了7h,感觉是收获颇丰的一天。其中1.5h跟外教激情对线,学习了如何正规且不失优雅的自我介绍,为CET4的口语考试奠定了一部分基础。之后的2.5h在学习和复习数据库的......
  • 机器学习评价指标之回归问题
    1.平均绝对误差:MAE(MeanAbsoluteError)2.均方误差:MSE(MeanSquaredError)3.均方根误差:RMSE(RootMeanSquardError)4.决定系数:R2(R-Square)5.校正决定系数(AdjustedR-......
  • 嵌入与表示学习 embedding & representation learning, embedding & Encoding
     视频:https://www.bilibili.com/video/BV1Cf4y1e7Ht/?spm_id_from=333.788&vd_source=6292df769fba3b00eb2ff1859b99d79e ========================================......
  • pandas vs sql 基本操作对比
    作为一名数据分析师,平常用的最多的工具是SQL(包括MySQL和HiveSQL等)。对于存储在数据库中的数据,自然用SQL提取会比较方便,但有时我们会处理一些文本数据(txt,csv),这个时候......
  • 学习安卓App开发的基本流程
    许多小伙伴想了解学习开发一个安卓系统的App大概需要什么流程,那我们简单看一下吧!第一、开发语言选择。语言其实只是开发实际应用的第一步,安卓开发的首选语言是Kotlin,次选......
  • 瞎聊机器学习——K-均值聚类(K-means)算法
    本文中我们将会聊到一种常用的无监督学习算法——K-means。1、K-means算法的原理K-means算法是一种迭代型的聚类算法,在算法中我们首先要随机确定K个初始点作为质心,然后去计......