首页 > 其他分享 >LLM大模型: Maskformer/Mask2Former语义分割原理详解

LLM大模型: Maskformer/Mask2Former语义分割原理详解

时间:2024-10-30 17:31:39浏览次数:7  
标签:Maskformer image attention mask LLM decoder query pixel Mask2Former

  1、自动驾驶、机器人、电商、监控等行业都涉及到image的sematic segmentation,传统的方式:per-pixel classification,每个像素点都要分类;如果进一步做 instance-level segmentation,可能还要改network architure后重新训练,很麻烦。FAIR在2021年10月份的时候发表了论文:Per-Pixel Classification is Not All You Need for Semantic Segmentation,单看标题就知道他们做Semantic Segmentation不再是Per-Pixel Classification,这帮人又是怎么做的了?原论文给的名称:per-mask classification,图示图下:

  

  • 图的左边:per pixel做分类,一共有K类,那么每个pixel都要计算这K类的概率,一共有k*h*w个数值;
  • 图的右边:per mask做分类,一共有K类,加上背景就是K+1类;既然是每个mask都做分类,那么mask有多少了?这里用N表示mask的数量!从图示看,每个mask都是image的一部分区域,然后对这部分区域做prediction,看看属于K+1中的哪类!比如图像中的第一个mask预测后就是building类, 第二个mask预测后就是sky类,最后一个mask啥也不是!

  整个过程思路很简单,所以现在最核心的问题来了:这里的mask是怎么划分的?比如上图有100个mask,这100个mask之间是怎么划分地盘的? 换句话说,mask是怎么精准对image中不同object描边的?

  2、原论文中网络架构如下图所示:

  

   从颜色来看,就可以直观的分成三部分,用原论文的描述如下:

  • pixel - level module:
    • A backbone that extracts low resolution features from an image.  从原始image中提取低分辨率的feather F;
    • A pixel decoder that gradually upsamples low-resolution features from the output of the backbone to generate high-resolution per-pixel embeddings.  然后通过decoder把低分辨率的feather转成高分辨率的per-pixel embeddings;这一步的结果就是:每个pixel都得到一个embedding representation,用来表示每个pixel的特征,所以这里得到的三维矩阵:C_epsilon * H * W  
  • transformer module:
    • And finally a Transformer decoder that operates on image features to process object queries. transformer融合image的特征F和N个queries(就是上图的mask)
  • segmentation module:
    • 上路classification loss:经过transformer decoder后,再经过一个MLP进行空间转换,得到C_epsilon * N,也就是每个query/mask的embedding representation,用于描述每个query/mask的特征。然后经过softmax得到了每个query/mask属于哪个object class的概率,所以这里得到的就是 N * (K+1)的二维矩阵
    • 下路binary mask loss:上路得到了每个query/mask的C_epsilon * N特征,又从pixel decoder得到了每个pixel的特征 C_epsilon * H * W ,因为pixel肯定属于某个query/mask,所以把这两相乘,就到了 N * H* W;因为 H * W 代表的就是每个pixel,所以这个三维矩阵表示的是pixel属于某个query/mask的概率
  •  上面就是train过程,接下来就是inference了,也就是最右边的模块: N * (K + 1) 表示query/mask属于某个object类别的概率分布,N * H* W 表示pixel属于某个query/mask的概率分布,这两个相乘,把N去掉,不就得到了 K * H * W了么?H * W是每个pixel像素的位置,K是类别,这个 K * H * W 不就表示每个pixel的K类概率分布了么?是不是感觉这个思路很巧妙了:把N个query/mask作为中间变量latent,巧妙地通过两个矩阵相乘消掉N,得到每个pixel的K类概率分布,这个思路是不是和bayes很像啊!用数学表达式:P(K|H,W) = P(K|N) * P(N|H, W)

  

  现在的问题来了:为什么不直接计算pixel的K类概率分布,而是要通过中间变量mask/query来中转?

  • LLM的fine-tune都知道吧!其中一种方式是Lora微调,就是把大matrix分解成两个小matrix相乘,来减少计算量和存储空间,这里的N作用类似:如果K比较大,直接求P(K|H, W)可能计算量比较大,所以这里分成两个小矩阵相乘的形式减少计算量!
  • 对于pixel的类别,如果在object内部,那么肯定就属于该object啦,这个很容易区分,但还是有些pos比较难区分:轮廓边缘,也就是准确描边:这上面的pixel肯定不是非黑即白的,需要有个概率分布描述所属类别比较合乎业务逻辑
  • query/mask N动态可调整,如果image的instance增加,可适当增加N来涵盖所有的instance;query/mask N在一定程度上可理解为instance的个数上限,每个query/mask对应一个instance,增加了模型的可解释性,也利于理解
  • query/mask N在transformer中和image的feather做cross attention,可捕捉全局特征,理解局部之间的关系,比如transformer的decoder可能会识别出一个query与“行人”类别的关联,并注意到这个query在图像中的位置与另一个表示“车辆”的query有空间上的联系。这种全局的视角使得模型能够推断出,尽管行人的脚部在局部看起来与道路相似,但从整体上看,这个区域应该是属于“行人”的一部分,而不是道路

  3、 maskformer发布后仅仅过了大半年,mask2former接着发布了,又做了哪些改进了?老规矩,先上原论文的图:

  

       原论文介绍如下:Mask2Former overview. Mask2Former adopts the same meta architecture as MaskFormer [14] with a backbone, a pixel decoder and a Transformer decoder. We propose a new Transformer decoder with masked attention instead of the standard cross-attention (Section 3.2.1). To deal with small objects, we propose an efficient way of utilizing high-resolution features from a pixel decoder by feeding one scale of the multi-scale feature to one Transformer decoder layer at a time (Section 3.2.2). In addition, we switch the order of self and cross-attention (i.e., our masked attention), make query features learnable, and remove dropout to make computation more effective (Section 3.2.3). Note that positional embeddings and predictions from intermediate Transformer decoder layers are omitted in this figure for readability. 从架构上看,mask2dormer和maskformer是一样的,没啥区别;最大的改进是使用了masked attention替代了标注你的cross attention。

  (1)Compared to the cross-attention used in a standard Transformer decoder which attends to all locations in an image, our masked attention leads to faster convergence and improved performance。maskformer中Image feather F进入transformer的decoder后,和queries做cross atttention,用论文作者的数据是训练一张图片至少需要32GB的显存(One limitation of training universal architectures is the large memory consumption due to high-resolution mask prediction, making them less accessible than the more memory-friendly specialized architectures [6, 24]. For example, MaskFormer [14] can only fit a single image in a GPU with 32G memory),所以这里必须要改。在mask2former中,用的就是mask attention了!先来回顾一下N query/mask的作用:在1代中,N query/mask需要和image feather F做cross attenion,用来确认N query/mask的K+1类的概率分布。其实在一张image中,真正的object能有多少了?图片中很大一部分都是backgroud,这部分pixel参与cross attetion合理么?这里的思路就可以借鉴一下perceiver resampler 了:根据query做downsample,把query对应的object特征提取出来,其他的特征一概不要,这里首当其冲要去掉的就是backgroud啦!那么问题又来了:怎么精确地找到backgroud了?或者说怎么找到正确的mask来屏蔽backgroud了?

   这种情况是监督学习,就是有大量标注好训练样本的,所以换个角度理解:训练过程中mask的位置肯定是不可能人为标注的,那就只能让模型自己通过Back proporgation学习了!标准的cross-attention是这么干的,如下图:

    

   masked attention是这么干的:transformer decoder的每一层都要加上mask的值,这个值来自上一层的transformer decoder layer,最早的Mask0来自X0,然后经过transformer decoder layer逐层生成后续的mask

    

   (2)原论文提出的第二个改进:Second, we use multi-scale high-resolution features which help the model to segment small objects/regions; 这里的multi-scale是通过pixel-decoder体现的。pixel decoder的输入是image feather,经过decoder后生成多个不同scale的high-resolution来捕捉不同层级的信息。比如低分辨率捕捉image的整体全局信息,诸如大的框架、色调、object位置等宏观信息;高分辨率捕捉的是细节,比如毛发、皮肤、衣着、五官等,所以原文说的是high-resolution features which help the model to segment small objects/regions

   (3)原论文提出的第三个改进:Third, we propose optimization improvements such as switching the order of self and cross-attention, making query features learnable, and removing dropout; all of which improve performance without additional compute;

  Finally, we save 3× training memory without affecting the performance by calculating mask loss on few randomly sampled points. These improvements not only boost the model performance, but also make training significantly easier, making universal architectures more accessible to users with limited compute

  

参考:

1、https://arxiv.org/pdf/2112.01527  Masked-attention Mask Transformer for Universal Image Segmentation    

2、https://github.com/facebookresearch/Mask2Former  

3、https://arxiv.org/pdf/2107.06278   Per-Pixel Classification is Not All You Need for Semantic Segmentation

4、https://www.bilibili.com/video/BV1EA22YnEY1?spm_id_from=333.788.videopod.episodes&vd_source=241a5bcb1c13e6828e519dd1f78f35b2&p=2

标签:Maskformer,image,attention,mask,LLM,decoder,query,pixel,Mask2Former
From: https://www.cnblogs.com/theseventhson/p/18513038

相关文章

  • LLM论文研读: GraphRAG的替代者LightRAG
    1. 背景最近有一个很火的开源项目LightRAG,Github6.4K+星※,北邮和港大联合出品,是一款微软GraphRAG的优秀替代者,因此本qiang~得了空闲,读读论文、跑跑源码,遂有了这篇文章。2. LightRAG框架2.1 已有RAG系统的局限性1)许多系统仅依赖于平面数据表示(如纯文本),限制了根据文本中......
  • 终于有了!!!基于Langgraph使用本地LLM搭建agent!!!
    需求Langchain是使用闭源LLM实现agent搭建的,Langgraph官网给的例子是基于Claude,其他一些agent例子也是基于OPENAI的,但是对于很多私有化场景,使用本地LLM搭建agent是非常重要的。但是网上并没有相关的教程,捣鼓了两天,捣鼓出来Ollama+Langgraph实现的基于本地LLM的agent搭建模......
  • 2025秋招LLM大模型多模态面试题(十三)- rag(检索增强生成)技术
    1.基本概念检索增强LLM(RetrievalAugmentedLLM),简单来说,就是给LLM提供外部数据库,对于用户问题(Query),通过一些信息检索(InformationRetrieval,IR)的技术,先从外部数据库中检索出和用户问题相关的信息,然后让LLM结合这些相关信息来生成结果。下图是一个检......
  • AI大模型(LLMs)五大热点研究方向分享!
    近年来,人工智能大模型(LLMs)的研究不断深入,衍生出了多个热门方向,聚焦提升模型的性能、适应性与应用场景,推动了技术的突破与革新。今天为大家梳理一下AI顶会上的五大热门研究方向,希望为那些专注大模型方向的研究者带来一些灵感和参考。Part.01检索增强生成(RAG)大模型虽然在生......
  • 清华:细粒度强化学习优化LLM工具使用
    ......
  • 人大:优化工具文档提升LLM工具使用
    ......
  • 全面解释人工智能LLM模型的真实工作原理(完结)
    前一篇:《全面解释人工智能LLM模型的真实工作原理(三)》序言:本节作为整篇的收官之作,自然少不了与当今最先进的AI模型相呼应。这里我们将简单介绍全球首家推动人工智能生成人类语言的公司——OpenAI的GPT模型的基本原理。如果你也希望为人类的发展做出贡献,并投身于AI行业,这无疑是一......
  • 大模型LLM:为什么简单的乘法ChatGPT会算错?
    首先“心算”三位整数乘法不管对人类还是对模型来说都不简单的。如果使用CoT的方式就类似于“笔算”,如果使用编程的方式就类似于人拿着计算器算。我将问题更精确一点地表述为“模型如何在心算多位整数乘法上接近或超过人的水平?”这个问题困扰了我很久,简单乘法是推理能力的......
  • 全面解释人工智能LLM模型的真实工作原理(三)
    前一篇:《全面解释人工智能LLM模型的真实工作原理(二)》序言:前面两节中,我们介绍了大语言模型的设计图和实现了一个能够生成自然语言的神经网络。这正是现代先进人工智能语言模型的雏形。不过,目前市面上的语言模型远比我们设计的这个复杂得多。那么,它们到底复杂在什么地方?本节将为你......
  • 使用Spring AI和LLM生成Java测试代码
    背景     AIDocumentLibraryChat项目已扩展至生成测试代码(Java代码已通过测试)。该项目可为公开的Github项目生成测试代码。只需提供要测试的类的网址,该类就会被加载、分析导入,项目中的依赖类也会被加载。这样,LLM就有机会在为测试生成模拟时考虑导入的源类。可以提供te......