首页 > 其他分享 >CLIP损失函数的理解

CLIP损失函数的理解

时间:2023-06-13 17:11:48浏览次数:47  
标签:loss 函数 CLIP embeds text image 损失 logits

  参考资料:

  [一个写的相当好的教程]

  [CLIP huggingface源码:CLIPModel]

  [CLIP huggingface训练例程]

  这篇文章首先展示CLIP损失函数的两种底层实现代码,然后聊一聊自己的理解。

  说实话念硕士的时候没有接触过CLIP这个东西,来实习之后发现这个多模态的模型使用非常广泛,设计理念也是看后惊为天人。加上最近有探究任务研究CLIP,BLIP这些,遂决心把这个模型弄懂。参考资料1已经把CLIP的设计思想,原理,甚至是底层实现给讲清楚了,但是当我读到训练的损失函数那一段的时候还是产生了很大的疑问:作者说有两种方式来计算损失函数,一种较为简单,一种较为复杂。较为复杂的损失函数实现如下:

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()

  其中Cross_entropy也是作者自己实现的,看上去就是logsoftmax加上NLLloss:

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

  较为简单的损失函数的实现则是这样:nn.CrossEntropyLoss()(logits, torch.arange(batch_size))

  作者在下面进行了分析,我看完分析之后觉得... ... 作者的语气好像是在说这种较为简单的损失函数是有误的,在数据集中有同一张图片的多个相似caption的时候会明显犯错。那么,较为复杂的损失函数就是正确的了。以上是Tutorial里作者的实现,较为权威的另一种实现是huggingface团队Transformer库里的源码。由于CLIP模型的高度可定制性,huggingface团队实现了一个基类,也就是CLIPModel部分。并在需要训练的时候把loss设置为forward函数的第一个返回值,我们来看一下他们的实现:

image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)

text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)

# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()

loss = None
if return_loss:
    loss = clip_loss(logits_per_text)

  其中,clip_loss的实现如下:

# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

  一开始的归一化比较好理解,logit_scale是一个超参数也好理解。最难理解的就是logits_per_text和logits_per_image这两个互为转置的矩阵。写这篇文章的时候我只能说自己弄懂了7分,原论文中有这么一段话:While standard image models jointly train an image feature extractor and a linear classifier to predict some label, CLIP jointly trains an image encoder and a text encoder to predict the correct pairings of a batch of (image, text) training examples. 即CLIP是学习(image, text)图文对之间的正确匹配的。这个正确匹配有两个对称的方面:1)对于每一个caption,和它吻合的图片得到label 1,和它不吻合的图片得到label 0。(这个对应于caption_loss)2)对于每一个image,和它吻合的caption得到label 1,和它不吻合的图片得到label 0。(这个对应于image_loss)而将两个loss相加除以2,得到的损失函数就同时考虑了两个方面了。如果一个模型在这两个方面都做得好,那么大概率是能够成功学习到correct pairings of a batch of (image, text) 的。

标签:loss,函数,CLIP,embeds,text,image,损失,logits
From: https://www.cnblogs.com/chester-cs/p/17478159.html

相关文章

  • C++ 虚函数与动态绑定
    多态与动态绑定为了实现C++的多态,C++使用了动态绑定技术,该技术的核心是虚函数表(简称虚表)。类的虚函数表每个包含了虚函数的类都包含一个虚表,一个子类如果继承了包含虚函数的父类,那么这个类也拥有自己的虚表,例如classA{public:virtualvoidvfunc1();virtualv......
  • 如何实现一个函数重载的功能
    函数重载将函数接收到的不同参数,进行不同处理。importcreateOverLoadfrom'./funReload.js'constgetUsers=createOverLoad()getUsers.addImpl(()=>{console.log('查询所有用户')})getUsers.addImpl('string',(name)=>{console.log('......
  • 常见m2eclipse安装错误及其解决方法
    最近学习maven,发现一些安装问题,从网上找了一些解决方法---------------------------------------------------------------------------------错误一:eclipse3.6.1安装maven插件失败解决方法:--------------------------------------------------------------------------------......
  • eclipse 3.6.1 安装maven插件失败的解决办法
      一、eclipse3.6.1下载地址[eclipse-jee-helios-SR1-win32.zip]http://www.eclipse.org/downloads/packages/eclipse-ide-java-ee-developers/heliossr1二、插件地址1、gef插件地址:http://download.eclipse.org/tools/gef/updates/interim/2、subclipse插件地址:http......
  • Eclipse环境搭建全集(个人使用的环境,Eclipse+SVN+Maven+JbossTo...
    评:1.JDK的配置去官网下载JDK,需要注意的是JDK32位,Eclipse也必须是32位.64位JDK对应Eclipse64位.1.安装JDK,安装过程中最好自定义安装目录等信息,如我们选择安装目录为E:\software\Java\jdk1.6.0_34.2.安装完成后,我的电脑点击属性,选择高级选项卡点击环境变量.3.在系统变量......
  • 【技术积累】JavaSciprt中的函数【一】
    什么是函数?如何声明函数?JavaScript中的函数是一段可重复使用的代码块,它可以接受输入并返回输出。在JavaScript中,函数是一种特殊的对象,因此可以将其存储在变量中,将其作为参数传递给其他函数,并从其他函数中返回。在JavaScript中,声明函数有两种方式:函数声明和函数表达式。1.函数......
  • DQL-聚合函数
           ......
  • C/C++学习(10)关于数组、内联函数、虚函数的错题集锦
    1、顺序存储方式不仅用于存储线性结构,还可以用于存放非线性结构,如完全二叉树是属于非线性结构,但其最佳存储方式是顺序存储方式。 2、数组名有两重属性:1)数据结构的一个对象(数据结构为当前数组),在java中数组就是一个对象。2)某些情况下自动退化成指向第一个元素的常量指针。 3、有两......
  • 14.拷贝构造函数、静态、友元和预编译头
    拷贝构造函数静态友元预编译头拷贝构造函数eg:Playeer.h代码:#pragmaonceclassPlayeer{private:intnum;char*name;public:Playeer(intx,constchar*name);~Playeer();voiddisplay();//输出结果voidsetX(intx);//......
  • C语言,函数包含失败问题
    1.头文件包含顺序出错导致头文件中的函数无法使用eg:在主函数中调用support.h中的strcat()函数失败,但是明明已经包含了strcat()函数的头文件进来;编译器还是提示“Undefinedsysbolsupport(refreedfromxxx.o)”.以下函数只是简单举例,请不要直接拿来编译main中,先调用了includ......