首页 > 其他分享 >forward_MDCS

forward_MDCS

时间:2024-11-01 09:46:39浏览次数:3  
标签:loss target dim MDCS torch softmax forward logits

最好的效果 47.35

    def forward(self, output_logits, target, extra_info=None):
        if extra_info is None:
            return self.base_loss(output_logits, target)  # output_logits indicates the final prediction

        loss = 0
        temperature_mean = 1
        temperature = 1
        # Obtain logits from each expert
        epoch = extra_info['epoch']
        num = int(target.shape[0] / 2)

        # expert1_logits = extra_info['logits'][0] + torch.log(torch.pow(self.prior, -0.5) + 1e-9)      #head
        #
        # expert2_logits = extra_info['logits'][1] + torch.log(torch.pow(self.prior, 1) + 1e-9)         #medium
        #
        # expert3_logits = extra_info['logits'][2] + torch.log(torch.pow(self.prior, 2.5) + 1e-9)       #few

        expert1_logits = extra_info['logits'][0] + torch.log(torch.pow(self.prior, 2) + 1e-9)      #head

        expert2_logits = extra_info['logits'][1] + torch.log(torch.pow(self.prior, 2) + 1e-9)         #medium

        expert3_logits = extra_info['logits'][2] + torch.log(torch.pow(self.prior, 2) + 1e-9)       #few



        teacher_expert1_logits = expert1_logits[:num, :]  # view1
        student_expert1_logits = expert1_logits[num:, :]  # view2

        # teacher_expert2_logits = expert2_logits[:num, :]  # view1
        # student_expert2_logits = expert2_logits[num:, :]  # view2
        #
        # teacher_expert3_logits = expert3_logits[:num, :]  # view1
        # student_expert3_logits = expert3_logits[num:, :]  # view2


        teacher_expert1_softmax = F.softmax((teacher_expert1_logits) / temperature, dim=1).detach()
        student_expert1_softmax = F.log_softmax(student_expert1_logits / temperature, dim=1)

        # teacher_expert2_softmax = F.softmax((teacher_expert2_logits) / temperature, dim=1).detach()
        # student_expert2_softmax = F.log_softmax(student_expert2_logits / temperature, dim=1)
        #
        # teacher_expert3_softmax = F.softmax((teacher_expert3_logits) / temperature, dim=1).detach()
        # student_expert3_softmax = F.log_softmax(student_expert3_logits / temperature, dim=1)



        teacher1_max, teacher1_index = torch.max(F.softmax((teacher_expert1_logits), dim=1).detach(), dim=1)
        student1_max, student1_index = torch.max(F.softmax((student_expert1_logits), dim=1).detach(), dim=1)

        # teacher2_max, teacher2_index = torch.max(F.softmax((teacher_expert2_logits), dim=1).detach(), dim=1)
        # student2_max, student2_index = torch.max(F.softmax((student_expert2_logits), dim=1).detach(), dim=1)
        #
        # teacher3_max, teacher3_index = torch.max(F.softmax((teacher_expert3_logits), dim=1).detach(), dim=1)
        # student3_max, student3_index = torch.max(F.softmax((student_expert3_logits), dim=1).detach(), dim=1)


        # distillation
        partial_target = target[:num]
        kl_loss = 0
        if torch.sum((teacher1_index == partial_target)) > 0:
            kl_loss = kl_loss + F.kl_div(student_expert1_softmax[(teacher1_index == partial_target)],
                                         teacher_expert1_softmax[(teacher1_index == partial_target)],
                                         reduction='batchmean') * (temperature ** 2)

        # if torch.sum((teacher2_index == partial_target)) > 0:
        #     kl_loss = kl_loss + F.kl_div(student_expert2_softmax[(teacher2_index == partial_target)],
        #                                  teacher_expert2_softmax[(teacher2_index == partial_target)],
        #                                  reduction='batchmean') * (temperature ** 2)
        #
        # if torch.sum((teacher3_index == partial_target)) > 0:
        #     kl_loss = kl_loss + F.kl_div(student_expert3_softmax[(teacher3_index == partial_target)],
        #                                  teacher_expert3_softmax[(teacher3_index == partial_target)],
        #                                  reduction='batchmean') * (temperature ** 2)

        # loss = loss + dkd_loss(expert1_logits, output_logits, target)
        loss = loss + dkd_loss(expert1_logits, output_logits, target)
        #
        loss = loss + 0.6 * kl_loss * min(extra_info['epoch'] / self.warmup, 1.0)



        # expert 1
        loss += self.base_loss(expert1_logits, target)

        # expert 2
        loss += self.base_loss(expert2_logits, target)

        # expert 3
        loss += self.base_loss(expert3_logits, target)


        return loss

标签:loss,target,dim,MDCS,torch,softmax,forward,logits
From: https://www.cnblogs.com/ZarkY/p/18519423

相关文章

  • Setting up a mobile hotspot on your Samsung Galaxy phone is straightforward
    SettingupamobilehotspotonyourSamsungGalaxyphoneisstraightforward.Herearethesteps:OpenSettings:Swipedownfromthetopofthescreentoopenthenotificationshade,thentapthegearicontoaccessSettings.Connections:TaponConnec......
  • 以pytorch的forward hook为例探究hook机制
    在看pytorch的nn.Module部分的源码的时候,看到了一堆"钩子",也就是hook,然后去研究了一下这是啥玩意。基本概念在深度学习中,hook是一种可以在模型的不同阶段插入自定义代码的机制。通过自定义数据在通过模型的特定层的额外行为,可以用来监控状态,协助调试,获得中间结果。以前向hook......
  • 前向声明Forward Declaration
    在C++中,前向声明(ForwardDeclaration)是一种声明类型(如类、结构体、联合等)而不提供完整定义的方法。前向声明允许你在使用某些类型时避免包含其完整定义,从而减少编译时间、避免循环依赖,并改善代码的可维护性。一、前向声明的基本语法前向声明的基本语法如下:classMyClass;......
  • useImperativeHandle, forwardRef ,使方法和属性应暴露给父组件
    useImperativeHandle是React中的一个Hook,用于自定义组件的实例值。它通常与forwardRef一起使用,允许你在父组件中通过引用(ref)访问子组件的特定实例方法或属性。以下是对useImperativeHandle的详细解析。1、语法importReact,{useImperativeHandle,forwardRef......
  • std::move()与std::forward()
    在C++中,右值、移动构造函数、std::move()、和std::forward()都是与优化和内存管理相关的概念,特别是在避免不必要的拷贝时有很大作用。1.右值(Rvalue)右值通常是表达式中不具有持久性的临时对象。它是不能通过变量名来引用的值,通常出现在赋值语句的右侧。常见的右值有:字面值:如5......
  • git pull 出现non-fast-forward的错误
    1.gitpullorigindaily_liu_0909:liu_0909出现non-fast-forward的错误,证明您的本地库跟远程库的提交记录不一致,即你的本地库版本需要更新2.gitresethead^若你的本地库已经commit最新的更改,则需要回到你的版本更改之前的版本3.gitadd.gitstash版本回退之后,您的更改......
  • Flink Forward Asia 2024 议题征集令|探索实时计算新边界
    简介:FlinkForwardAsia2024将于11月29日至30日在上海举行,现公开征集议题。作为ApacheFlink社区的重要年度活动,大会旨在汇集行业最佳实践和技术动态。议题覆盖流式湖仓、流批一体、Al大模型、生产实践等方向,并特别关注ApachePaimon和FlinkCDC等社区项目。所有议题将由专......
  • Nginx $remote_addr和$proxy_add_x_forwarded_for变量的实现
    $remote_addr代表客户端IP。注意,这里的客户端指的是直接请求Nginx的客户端,非间接请求的客户端。假设用户请求过程如下:用户客户端--发送请求->Nginx1--转发请求-->Nginx2->后端服务器那么,默认情况下,针对Nginx1而言,$remote_addr为用户客户端IP,对Nginx2而言,$remote_addr则为Ngi......
  • C++学习随笔——C++11的array、forward_list、tuple的用法
    1.std::arraystd::array是C++11引入的一个封装了原生数组的容器,它结合了C++标准库容器的优点和C风格数组的效率。#include<array>#include<iostream>intmain(){std::array<int,5>arr={1,2,3,4,5};//初始化一个大小为5的数组//访问元素......
  • forward_list
    forward_listforward_list容器以单链表的形式存储元素。forward_list的模板定义在头文件forward_list中。fdrward_list和list最主要的区别是:它不能反向遍历元素;只能从头到尾遍历。forward_list的单向链接性也意味着它会有一些其他的特性:无法使用反向迭代器。只能从它......