首页 > 其他分享 >LLM相关损失函数

LLM相关损失函数

时间:2024-05-22 10:53:45浏览次数:25  
标签:KL 函数 torch 损失 kl LLM d1 d2 tensor

信息熵:

信息熵torch代码
event = {'a':2 , 'b':2, 'c':4}      # 信息熵分:1.5
event2 = {'a':1 , 'b':1, 'c':1}     # 信息熵分:1.585
p_e = [ v/sum(event.values()) for v in event.values() ]
en_e = [ item*torch.log2(torch.tensor(item)) for item in p_e ]
print(en_e)
info_entropy = -torch.sum(torch.tensor(en_e))

相对熵:KL散度

  • KL:衡量两个分布的差异
  • KL越大:分布差异大 / 拟合损失大 / 模型优化难度大
  • KL(P||Q)通常不等于KL(Q||P),概率分布一样,两者才会相等且为0。分别表示用分布 Q 拟合 P
  • KL(ASR-wenet端到端识别中):模型生成分布为(T,D,P)真实标签(T,D,1/n)
1维tensor计算
import torch.nn.functional as F
x = torch.tensor([0.5, 0.5])
y = torch.tensor([0.2, 0.8])
logp_x = torch.softmax(x, dim=-1)
p_y = torch.softmax(y, dim=-1)
kl_mean = F.kl_div(logp_x, p_y, reduction='mean')
kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
kl_default = F.kl_div(logp_x, p_y )

d1 = [0.5, 0.5]
d2 = [0.2, 0.8]
d1 = torch.softmax( torch.tensor(d1), dim=-1 )
d2 = torch.softmax( torch.tensor(d2), dim=-1 )
def kl_self(d1, d2):
    return torch.tensor( [ d2[id]*(torch.log(d2[id])-v) for id, v in enumerate(d1) ] )
kl_self(logp_x, p_y).sum()
KL多维tensor计算(摘自wenet,与手写不一致可能是softmax部分)
d1 = [0.5, 0.5]
d2 = [0.2, 0.8]
kl = torch.nn.KLDivLoss(reduction="none")
kl( torch.tensor(d2) , torch.tensor(d1) )
# 手写
d1 = torch.softmax( torch.tensor(d1), dim=-1 )
d2 = torch.softmax( torch.tensor(d2), dim=-1 )
torch.tensor( [ d2[id]*(torch.log(d2[id])-v) for id, v in enumerate(d1) ] )

标签:KL,函数,torch,损失,kl,LLM,d1,d2,tensor
From: https://www.cnblogs.com/lhx9527/p/18204599

相关文章

  • LLM-文心一言:modbus、opc、can、mqtt协议
    Modbus、OPC、CAN和MQTT都是不同的通信协议,它们在工业自动化、物联网和其他领域有着广泛的应用。以下是对这些协议的简要介绍:Modbus:Modbus是一种串行通信协议,由Modicon公司(现为施耐德电气的一部分)在1979年提出,用于可编程逻辑控制器(PLC)之间的通信。它已经成为工业领域通信协议的......
  • Hooking linux内核函数(一)
    本文是《HookingLinuxKernelFunctions,Part1:LookingforthePerfectSolution》的翻译文章。前言我们最近参与了一个Linux系统安全相关项目,需要hooking几个重要的Linux内核函数调用,例如打开文件和启动进程,并利用它来启用系统活动监控并抢先阻止可疑进程。最后,我们发明......
  • Hooking linux内核函数(二):如何使用Ftrace hook函数
    本文是《HookingLinuxKernelFunctions,Part2:HowtoHookFunctionswithFtrace》的翻译文章前言Ftrace是一个用于跟踪Linux内核函数的Linux内核框架。但是,当我们尝试启用系统活动监控以阻止可疑进程时,我们的团队设法找到了一种使用ftrace的新方法。事实证明,ftrace允许......
  • 欧拉函数
    一、欧拉函数定义\([1,n]\)中与\(n\)互质的数的个数,称为欧拉函数,记为\(\varphi(n)\)。互质的定义:对于正整数\(a\)和\(b\),若\(gcd(a,b)=1\),则\(a\)和\(b\)互质。性质若\(p\)是质数,则\(\varphi(p)=p-1\)。证:因为\(p\)是质数,所以因数只有\(1\)和\(p\)。......
  • python中那些双下划线开头得函数和变量
    Python中下划线---完全解读Python用下划线作为变量前缀和后缀指定特殊变量_xxx不能用frommoduleimport*导入__xxx__系统定义名字__xxx类中的私有变量名核心风格:避免用下划线作为变量名的开始。因为下划线对解释器有特殊的意义,而且是内建标识符所使用的符号,我们建议程......
  • 不同场景下的构造函数调用
    本文为对不同场景下的构造函数调用进行跟踪。构造函数默认情况下,在C++之后至少存在六个函数默认构造/析构函数,复制构造/复制赋值,移动构造/移动赋值。以下代码观测发生调用的场景#include<iostream>structFoo{Foo():fd(0){std::cout<<"Foo::Foo()this="<<......
  • 再探虚函数
    虚函数是一种成员函数,其行为可以在派生类中被覆盖,支持动态调用派发。使用示例代码如下:extern"C"{//避免operator<<多次调用,简化汇编代码voidprintln(constchar*s){std::cout<<s<<std::endl;}}void*operatornew(size_tn){void*p=malloc(n);......
  • c++ 结构体的构造函数
    结构体中构造函数1、不使用构造函数1#include<iostream>23structstudent{45intage;6std::stringgender;78}Liu;910intmain(){11Liu.age=20;12Liu.gender="man";1314std::cout<<Liu.age<......
  • Flink富函数
      富函数是DataStreamAPI提供的函数接口,Flink的函数都有它的Rich版本,它与其他函数不同的是,富函数可以获取到运行环境上下文,初始化参数,拥有生命周期方法等,可通过它进行自定义复杂功能。我们常见的如RichMapFunction、RichFilterFunction等。    富函数的生命周期主要通过......
  • 【代码】--库函数学习 temperature.c
    1. 封装的函数   用到了内核中的hwmon子系统,   hwmon子系统作为Linux内核中的一个子系统,用于监控硬件传感器的状态(设备的温度、电压和风扇转速)和提供对硬件传感器的访问接口。   在应用层,对传感器信息的读取,本质上是对驱动中hwmon子系统在注册传感器设备时所......