首页 > 其他分享 >Incrementer:Transformer for Class-Incremental Semantic Segmentation with Knowledge Distillation Focu

Incrementer:Transformer for Class-Incremental Semantic Segmentation with Knowledge Distillation Focu

时间:2023-06-03 16:56:00浏览次数:41  
标签:Transformer Semantic 蒸馏 新类 模型 class token 类别 Class

摘要

目前已有的连续语义分割方法通常基于卷积神经网络,需要添加额外的卷积层来分辨新类别,且在蒸馏特征时没有对属于旧类别/新类别的区域加以区分。为此,作者提出了基于Transformer的网络incrementer,在学习新类别时只需要往decoder中加入对应的token。同时,作者还提出了对于旧类别区域的蒸馏以及类别解混策略。最终模型在VOC和ADE两个数据集上都达到了SOTA。

方法

截屏2023-06-01 16.17.07.png

Incrementer Structure

模型整体框架就是一个Vision Transformer的结构,分为编码器和解码器两部分。特别的,为了在不添加额外网络结构的前提下实现增量过程中对于新类别的学习,作者收到Segmenter的启发,为每个类别(包括背景)分配了一个可学习的class token,之后将这些class token与encoder的输出进行拼接,送入decoder以得到对应的visual embeddings以及class embeddings。最后,每个类别的mask通过计算class embedding与visual embedding的相似度得到。在增量学习的过程中,为了防止相似度分数的计算受到新类别的影响,作者对于mask的生成使用了余弦相似度,首先对embedding进行l2归一化,之后通过下式得到mask:\(S'=\gamma \overline{e}_c\overline{e}_v^T,S'\in \mathbb{R}^{(M+1)\times N}\)。将其插值回原图大小就得到了最终的分割结果。基于上述框架,可以很方便地通过添加新类的class token实现类增量学习,在每个增量步骤t将旧类的class token冻结,与新类的class token还有visual feature一起送入decoder得到分割图。为了缓解背景偏移,作者使用了PLOP的伪标签策略,最终的损失计算如下:
截屏2023-06-02 22.31.58.png
其中\(\omega_i\)的作用是减轻模型对于新类的过拟合。

Knowledge Distillation Focusing on Old Class

现有的知识蒸馏方法不对特征图中属于不同类别的区域加以区分,旧模型会将新类别对应的区域认作背景,直接蒸馏会限制新模型的可塑性。为此,作者提出了FOD,只对visual feature中不属于新类别的区域进行蒸馏,即根据gt选取不包含新类别像素的visual tokens进行蒸馏。因为在生成mask时使用了cosine similarity,这里同样使用cosine similarity以保持相似度测量的一致性,避免使用l2-distance等硬知识蒸馏损失而牺牲可塑性。蒸馏损失计算如下:
截屏2023-06-02 23.01.18.png
其中:
截屏2023-06-02 23.02.02.png
对于伪标签中标记为背景的像素(不确定是旧类or新类)乘以系数\(\frac{C^{0:t-1}}{C^{0:t}}\)进行调节。如此只是在局部进行蒸馏,作者还对新旧模型decoder输出的旧类别class tokens进行蒸馏,从而约束模型对于全局特征的预测一致,公式如下(\(\beta_i\)的含义与上面的\(\alpha_i\)类似):
截屏2023-06-02 23.23.02.png
最终总的蒸馏损失就是\(L_{FOD}=L_{Vis_FOD}+L_{Cls_FOD}\)。

Class Deconfusion Strategy

在每次学习一个类且学习的步骤有多个时,模型对于新类的过拟合现象较为严重,同时,如果新类与某些旧类较为相似,模型则很容易将其混淆。为此,作者提出了CDS,首先通过权重\(\omega_i=\lambda \sqrt{\frac{C^t}{C^{0:t}}}\)降低对于新类的损失;其次,为了减轻模型对于新类与旧类的混淆问题,作者通过模型生成一个mask \(B^t\)以区分新类和旧类:
截屏2023-06-03 15.49.04.png
同时使用预测结果\(S^t\in \mathbb{R}^{|C^{0:t}|\times H\times W}\)沿通道维度相加,生成\(S_O^t\)以及\(S_N^t\):
截屏2023-06-03 16.22.22.png
最后通过\(B^t\)监督\(S_O\)和\(S_N\)的生成,损失为dice loss:
截屏2023-06-03 16.34.28.png
其中\(\tilde{B}_i^t\)是对\(B_i^t\)取反得到的。
总的损失为\(L=L_{WCE}+L_{FOD}+L_{BM}\)。

实验

截屏2023-06-03 16.36.53.png
截屏2023-06-03 16.37.04.png
截屏2023-06-03 16.37.18.png
截屏2023-06-03 16.37.40.png
直接把10-1刷到70多了,ade的四个setting都刷到了40+,非常恐怖。
截屏2023-06-03 16.38.51.png

标签:Transformer,Semantic,蒸馏,新类,模型,class,token,类别,Class
From: https://www.cnblogs.com/lipoicyclic/p/17454205.html

相关文章

  • Hugging News #0602: Transformers Agents 介绍、大语言模型排行榜发布!
    每一周,我们的同事都会向社区的成员们发布一些关于HuggingFace相关的更新,包括我们的产品和平台更新、社区活动、学习资源和内容更新、开源库和模型更新等,我们将其称之为「HuggingNews」,本期HuggingNews有哪些有趣的消息,快来看看吧!重磅更新TransformersAgents发布,通过自......
  • Vue——属性指令、style和class、条件渲染、列表渲染、事件处理、数据双向绑定、过滤
    vm对象<body><divid="app"><h1>{{name}}</h1><button@click="handleClick">点我</button></div></body><script>//1写在data或method中的属性或方法,从vm中直接可以点出来//2method的函数中,如......
  • springboot项目rabbitmq消费者消费json格式的String,出现无限循环抛出No method found
    转:springboot项目rabbitmq消费者消费json格式的String,出现无限循环抛出Nomethodfoundforclass[B     ......
  • idea - 文件后面显示 0%classes,0% lines covered -解决
    这是因为不小心按了 可以ctrl+Alt+F6 选择"nocoverage "按钮即可......
  • transformers入门使用
    transformers入门使用HuggingFace是一个开源社区,它提供了先进的NLP模型,数据集,以及其他工具。模型:https://huggingface.co/models数据集:https://huggingface.co/datasets主要的模型自然回归:GPT2,Trasnformer-XL,XLNet自编码:BERT,ALBERT,ROBERTa,ELECTRASto......
  • Neo4j导入RDF文件之neosemantics安装
    一、背景众所周知,Neo4j本身是不支持导入RDF文件的。但我们可以借助neosemantics来进行导入。neosemantics简称n10s,4.0之前的把呢不能叫semantics。现在官网也无法访问4.0之前版本的相关文档了。推荐大家使用4.0以上的。虽然我用的是3.5.5二、安装插件1、下载......
  • 【十六】classmethod()函数(1)
    【十六】classmethod()函数(1)【1】作用将方法转换为类方法。类方法将类作为隐式第一个参数接收,就像实例方法接收实例一样。类方法可以在类(如c.f())上调用,也可以在实例(如c().f())上调用。除了类之外,实例被忽略。如果为派生类调用了类方法,则派生类对象将作为隐含的第一个......
  • 【二十】issubclass()函数 -- 检查类型(1)
    【二十】issubclass()函数--检查类型(1)【1】作用Python提供了如下两个函数来检查类型:issubclass(cls,class_or_tuple):检查cls是否为后一个类或元组包含的多个类中任意类的子类。isinstance(obj,class_or_tuple):检查obj是否为后一个类或元组包含的多个类中......
  • java.lang.ClassNotFoundException: weblogic.utils.NestedException
    我单元测试的时候报这种错误Causedby:java.lang.ClassNotFoundException:weblogic.utils.NestedException atjava.net.URLClassLoader$1.run(URLClassLoader.java:202) atjava.security.AccessController.doPrivileged(NativeMethod) atjava.net.URLClassLoader.findC......
  • class loader
    http://docs.oracle.com/javase/7/docs/technotes/tools/findingclasses.htmlhttps://docs.oracle.com/javase/tutorial/ext/basics/load.htmlhttp://docs.oracle.com/javase/7/docs/technotes/guides/security/spec/security-spec.doc5.htmlhttps://docs.oracle.com/javas......