首页 > 其他分享 >hello

hello

时间:2024-08-05 16:42:36浏览次数:6  
标签:loss target dim torch softmax logits hello

点击查看代码
    def forward(self, output_logits, target, extra_info=None):
        if extra_info is None:
            return self.base_loss(output_logits, target)  # output_logits indicates the final prediction

        loss = 0
        temperature_mean = 1
        temperature = 1  
        # Obtain logits from each expert
        epoch = extra_info['epoch']
        num = int(target.shape[0] / 2)

        expert1_logits = extra_info['logits'][0] + torch.log(torch.pow(self.prior, -0.5) + 1e-9)      #head

        expert2_logits = extra_info['logits'][1] + torch.log(torch.pow(self.prior, 1) + 1e-9)         #medium

        expert3_logits = extra_info['logits'][2] + torch.log(torch.pow(self.prior, 2.5) + 1e-9)       #few

        teacher_expert1_logits = expert1_logits[:num, :]  # view1
        student_expert1_logits = expert1_logits[num:, :]  # view2

        teacher_expert2_logits = expert2_logits[:num, :]  # view1
        student_expert2_logits = expert2_logits[num:, :]  # view2

        teacher_expert3_logits = expert3_logits[:num, :]  # view1
        student_expert3_logits = expert3_logits[num:, :]  # view2

        teacher_expert1_softmax = F.softmax((teacher_expert1_logits) / temperature, dim=1).detach()
        student_expert1_softmax = F.log_softmax(student_expert1_logits / temperature, dim=1)

        teacher_expert2_softmax = F.softmax((teacher_expert2_logits) / temperature, dim=1).detach()
        student_expert2_softmax = F.log_softmax(student_expert2_logits / temperature, dim=1)

        teacher_expert3_softmax = F.softmax((teacher_expert3_logits) / temperature, dim=1).detach()
        student_expert3_softmax = F.log_softmax(student_expert3_logits / temperature, dim=1)

        teacher1_max, teacher1_index = torch.max(F.softmax((teacher_expert1_logits), dim=1).detach(), dim=1)
        student1_max, student1_index = torch.max(F.softmax((student_expert1_logits), dim=1).detach(), dim=1)

        teacher2_max, teacher2_index = torch.max(F.softmax((teacher_expert2_logits), dim=1).detach(), dim=1)
        student2_max, student2_index = torch.max(F.softmax((student_expert2_logits), dim=1).detach(), dim=1)

        teacher3_max, teacher3_index = torch.max(F.softmax((teacher_expert3_logits), dim=1).detach(), dim=1)
        student3_max, student3_index = torch.max(F.softmax((student_expert3_logits), dim=1).detach(), dim=1)

        # distillation
        partial_target = target[:num]
        kl_loss = 0
        if torch.sum((teacher1_index == partial_target)) > 0:
            kl_loss = kl_loss + F.kl_div(student_expert1_softmax[(teacher1_index == partial_target)],
                                         teacher_expert1_softmax[(teacher1_index == partial_target)],
                                         reduction='batchmean') * (temperature ** 2)

        if torch.sum((teacher2_index == partial_target)) > 0:
            kl_loss = kl_loss + F.kl_div(student_expert2_softmax[(teacher2_index == partial_target)],
                                         teacher_expert2_softmax[(teacher2_index == partial_target)],
                                         reduction='batchmean') * (temperature ** 2)

        if torch.sum((teacher3_index == partial_target)) > 0:
            kl_loss = kl_loss + F.kl_div(student_expert3_softmax[(teacher3_index == partial_target)],
                                         teacher_expert3_softmax[(teacher3_index == partial_target)],
                                         reduction='batchmean') * (temperature ** 2)

        loss = loss + 0.6 * kl_loss * min(extra_info['epoch'] / self.warmup, 1.0)

        # expert 1
        loss += self.base_loss(expert1_logits, target)

        # expert 2
        loss += self.base_loss(expert2_logits, target)

        # expert 3
        loss += self.base_loss(expert3_logits, target)

        return loss

标签:loss,target,dim,torch,softmax,logits,hello
From: https://www.cnblogs.com/ZarkY/p/18343493

相关文章

  • dotnet hello world
    参考资料dotnet命令参考使用dotnettest和xUnit在.NET中对C#进行单元测试DeclaringInternalsVisibleTointhecsprojXUnit输出消息创建控制台项目#创建项目目录mdDotnetStudycdDotnetStudy#创建解决方案dotnetnewsln#创建控制台项目,-n:名称,--use......
  • 开发调试驱动helloworld
    开发调试驱动helloworldhttps://learn.microsoft.com/zh-cn/windows-hardware/drivers配置开发环境https://learn.microsoft.com/zh-cn/windows-hardware/drivers/download-the-wdk按照步骤依次安装VisualStudioCommunity、SDK、WDK这里的windbg界面更现代一点https://lea......
  • webservice hello
    一、//hello.hintns__hello(std::stringname,std::string&greeting);二、//helloclient.cpp#include"soapH.h"#include"ns.nsmap"#include<string>#include<iostream>usingnamespacestd;inthello(structsoap*so......
  • 信步漫谈之Android——HelloWorld
    目录目标1资源2第一个HelloWorld程序3项目结构说明3.1目录结构3.2结构说明4在App中添加日志后续补充参考资料目标学习搭建Android的开发环境sayhelloworld1资源官网教程:https://developer.android.com/courses开发工具AndroidStudio下载路径:https://d......
  • 信步漫谈之微信小程序——HelloWorld
    目录目标1资源2程序目录说明3第一个HelloWorld程序4真机调试参考资料(感谢)目标微信小程序开发环境sayhelloworld1资源微信官方文档:https://developers.weixin.qq.com/doc/微信开发者工具下载:https://developers.weixin.qq.com/miniprogram/dev/devtools/downloa......
  • Electron学习笔记(二)Hello World
    目录前言运行主进程创建界面使用窗口打开界面管理窗口的生命周期关闭所有窗口时退出应用(Windows&Linux)​如果没有窗口打开则打开一个窗口(macOS)使用预加载脚本访问渲染器的Node.js添加你自己的功能完整代码展示效果展示前言接上一篇文章Electron学习笔......
  • QOJ7899 Say Hello to the Future
    考虑先求出原序列的方案数设\(f_i\)表示\(1\simi\)被划分为若干区间的方案数,若一段区间合法当且仅当\(r-l+1\ge\max\{a_{l\simr}\}\),可以发现数据结构难以维护且由于不是最优性问题,考虑\(\texttt{cdq}\)分治优化对于每个分治中心\(m\),令\(mxL_i=\max\{a_{i\si......
  • 【Blog1】PyCharm写hello world
    创建项目选择项目类型和项目目录设置python编译器选择编译器新建Python文件编写代码并运行运行运行结果......
  • 最长的Hello, World!(C++)
    最长的Hello,World!(C++)#include<iostream>#include<string>#include<vector>#include<memory>#include<random>_<typenameT>classNode{public:Tdata;std::shared_ptr<Node<T>>next;Node(......
  • 最长的Hello, World!(Python)
    最长的Hello,World!(Python)(lambda_,__,___,____,_____,______,_______,________:getattr(__import__(True.__class__.__name__[_]+[].__class__.__name__[__]),().__class__.__eq__.__class__.__name__[:__]+().__iter__().__cla......