首页 > 其他分享 >模型蒸馏

模型蒸馏

时间:2023-04-10 10:26:34浏览次数:46  
标签:loss 蒸馏 kd 模型 student logits data teacher

  蒸馏过程:

for epoch in range(epochs):
    student_model.train()
    for batch, (data, target) in enumerate(train_loader):
        student_logits = student_model(data)
        // 教师不更新
        with torch.no_grad():
            teacher_logits = teacher_model(data)
        # student与label的loss
        loss_cri = F.cross_entropy(y_s, target)

        # student与teacher的loss
        loss_kd = soft_cross_entropy(student_logits/temperature, teacher_logits/temperature)
        ## kd loss
        #p_s = F.log_softmax(student_logits/kd_T, dim=1)
        #p_t = F.softmax(teacher_logits/kd_T, dim=1)
        #loss_kd = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / student_logits.shape[0]
        
        # total loss
        loss = alpha * loss_cri + beta * loss_kd
        loss.backward()
        optimizer.zero_grad()

github链接:https://github.com/huawei-noah/Pretrained-Language-Model/blob/master/TinyBERT/task_distill.py

参考文献

标签:loss,蒸馏,kd,模型,student,logits,data,teacher
From: https://www.cnblogs.com/3511rjzn/p/17301994.html

相关文章

  • 碉堡!“万物皆可分”标记模型上线「GitHub 热点速览」
    这周有个让人眼前一亮的图像识别模型segment-anything,它能精细地框出所有可见物体,它标记出的物体边界线清晰可见。如此出色的模型,自然获得了不少人的赞赏,开源没几天,就拿下了18k+的star,而上周开源不到48小时获得35k+star的推特推荐算法,本周也成功突破50k+关卡。依旧是......
  • 如何打造一个适合现代企业的完美业务模型?
    容量场景中,每个业务比例都要符合真实业务场景的比例。不符合,那场景的执行结果也没意义。但很多性能人员因为对业务模型的抽取过程不了解或拿不到具体数据,导致业务模型和生产业务场景不匹配,整个性能项目都变得无意义也有大量项目,并没有拿历史业务数据做统计,直接非常笼统地拍脑袋,给出......
  • stable diffusion打造自己专属的LORA模型
    通过Lora小模型可以控制很多特定场景的内容生成。但是那些模型是别人训练好的,你肯定很好奇,我也想训练一个自己的专属模型(也叫炼丹~_~)。甚至可以训练一个专属家庭版的模型(familymodel),非常有意思。将自己的训练好的Lora模型放到stableDiffusionlora目录中,同时配上美丽的封面图。......
  • stable diffusion打造自己专属的LORA模型
    通过Lora小模型可以控制很多特定场景的内容生成。但是那些模型是别人训练好的,你肯定很好奇,我也想训练一个自己的专属模型(也叫炼丹~_~)。甚至可以训练一个专属家庭版的模型(familymodel),非常有意思。将自己的训练好的Lora模型放到stableDiffusionlora目录中,同时配上美丽的封面图。......
  • HBase在进行模型设计时重点在什么地方?一张表中定义多少个Column Family最合适?为什么?
     锁屏面试题百日百刷,每个工作日坚持更新面试题。请看到最后就能获取你想要的,接下来的是今日的面试题: 1.Hbase中的memstore是用来做什么的?hbase为了保证随机读取的性能,所以hfile里面的rowkey是有序的。当客户端的请求在到达regionserver之后,为了保证写入rowkey的有序性,所以......
  • 浏览器层面优化前端性能(1):Chrom组件与进程/线程模型分析
    现阶段的浏览器运行在一个单用户,多合作,多任务的操作系统中。一个糟糕的网页同样可以让一个现代的浏览器崩溃。其原因可能是一个插件出现bug,最终的结果是整个浏览器以及其他正在运行的标签被销毁。现代操作系统已经非常健壮了,它让应用程序在各自的进程中运行和不会影响到其他程序......
  • 基于AutomationML的多模型数字孪生驱动方法
    【场景】:终于要毕业了,从一开始都不知道数字孪生是什么,在没有老师和师兄师姐铺路的情况下,一点点看论文,复现论文,找创新点,真的太苦了。这里将我这三年在数字孪生的研究简单记录、分享一下,希望能帮到某些人,水平有限,不喜勿喷。我所了解到的,现有数字孪生的主流实现方法大......
  • 千“垂”百炼:垂直领域与语言模型(1)
    UsingLanguageModelsinSpecificDomains(1)微信公众号版本:https://mp.weixin.qq.com/s/G24skuUbyrSatxWczVxEAg这一系列文章仍然坚持走“通俗理解”的风格,用尽量简短、简单、通俗的话来描述清楚每一件事情。本系列主要关注语言模型在垂直领域尝试的相关工作。Thisseries......
  • 13.颜色模型与转换
    本小节中将介绍几种OpenCV4中能够互相转换的常见的颜色模型,例如RGB模型、HSV模型、Lab模型、YUV模型以及GRAY模型,并介绍这几种模型之间的数学转换关系,以及OpenCV4中提供的这几种模型之间的变换函数。1、RGB颜色模型RGB颜色模型的命名方式是采用三种颜色的英文首字母组成,分......
  • VGG16模型-tensorflow实现的架构
    importtensorflowastffromtensorflow.keras.modelsimportSequentialfromtensorflow.keras.layersimportInputLayer,Dense,Flatten,Conv2D,MaxPooling2Dfromtensorflow.keras.optimizersimportAdamdefbuild_vgg16(input_shape,num_classes):model......