首页 > 其他分享 >均方误差损失函数(MSE)和交叉熵损失函数详解

均方误差损失函数(MSE)和交叉熵损失函数详解

时间:2024-12-21 21:32:44浏览次数:4  
标签:loss 函数 交叉 样本 损失 均方 类别

为什么需要损失函数

前面的文章我们已经从模型角度介绍了损失函数,对于神经网络的训练,首先根据特征输入和初始的参数,前向传播计算出预测结果,然后与真实结果进行比较,得到它们之间的差值。

损失函数又可称为代价函数或目标函数,是用来衡量算法模型预测结果和真实标签之间吻合程度(误差)的函数。通常会选择非负数作为预测值和真实值之间的误差,误差越小,则模型越好。

     有了这个损失函数,我们便可以采用优化算法更新网络参数,使得训练样本的平均损失最小。

而损失函数根据任务的不同,也可以分为不同的类型,下面进行介绍。

 

均方误差损失函数(MSE)

其中f(xi)是第i个样本的模型预测值,Yi是第i个样本的真实标签值,二者差值求平方,一共有n个样本,平方和求平均。

在回归问题中,均方误差损失函数用于度量样本点到回归曲线的距离,通过最小化平方损失使样本点可以更好地拟合回归曲线。由于无参数、计算成本低和具有明确物理意义等优点,MSE已成为一种优秀的距离度量方法。尽管MSE在图像和语音处理方面表现较弱,但它仍是评价信号质量的标准。

代码实现:

import numpy as np

# 自定义实现

def MSELoss(x:list,y:list):

    """    x:list,代表模型预测的一组数据    y:list,代表真实样本对应的一组数据    """

    assert len(x)==len(y)

    x=np.array(x)

    y=np.array(y)

    loss=np.sum(np.square(x - y)) / len(x)

    return loss

#计算过程举例x=[1,2]y=[0,1]loss=((1-0)**2 + (2-1)**2)÷2=(1+1)÷2=1

# pytorch版本

loss = nn.MSELoss()

predict = torch.randn(3, 5, requires_grad=True)

target = torch.randn(3, 5)

output = loss(predict, target)

从代码中可以看到,MSELoss需要的两个参数分别是真实标签值和模型预测值,两者可以是任意形状的张量,但二者形状和维度需要一致。就是说每个样本的预测值和标签值可以是任意维度的张量,这点要注意,在实际应用中时要认真考虑标签的形状。

 

交叉熵损失

pytorch中的CrossEntropyLoss()函数实际就是先把输出结果进行sigmoid,随后再放到传统的交叉熵函数中,就会得到结果。

交叉熵是信息论中的一个概念,最初用于估算平均编码长度,引入机器学习后,用于评估当前训练得到的概率分布与真实分布的差异情况。为了使神经网络的每一层输出从线性组合转为非线性逼近,以提高模型的预测精度,在以交叉熵为损失函数的神经网络模型中一般选用tanh、sigmoid、softmax或ReLU作为激活函数

交叉熵损失函数刻画了实际输出概率与期望输出概率之间的相似度,也就是交叉熵的值越小,两个概率分布就越接近,特别是在正负样本不均衡的分类问题中,常用交叉熵作为损失函数。目前,交叉熵损失函数是卷积神经网络中最常使用的分类损失函数,它可以有效避免梯度消散。在二分类情况下也叫做对数损失函数。

一般的交叉熵用数学公式表示是:

-Q(x) log P(x)

其中Q(x)是真实值,P(x)是预测值。

当p(x)和Q(x)是矩阵的时候,就分别对其计算,然后求和即可

在pytorch中的交叉熵损失CrossEntropyLoss 包含了两部分,softmax和交叉熵计算,下面分别介绍这两部分

假设有 N 个样本,每个样本属于 C 个类别之一。对于第 i 个样本,它的真实类别标签为 y_i,模型的预测输出 logits 为xi​=(xi1​,xi2​,…,xiC​),其中xic表示第i个样本在第c 类别上的原始输出分数(logits)(注意这里是预测分数值,不是概率值)。

交叉熵损失的计算步骤如下:

(1)预测概率分布
对 logits 进行 softmax 操作,将预测输出其转换为概率分布:

其中 pic表示第i个样本属于第c类别的预测概率。

   此时预测输出的概率分布是f(xi)=(pi1,pi2,…,piC)

  1. 真实概率分布:

对于样本i,其真实分布会根据归属的类别自动创建一个one-hot概率分布,即所属类别的位置为1,其它均为0,则会输出一个one-hot概率分布Q(xi)=(qi1,qi2,…,qiC)。比如5个类别,第i个样本的真实类别为3,则Q(xi)[0,0,1,0,0]。

实际计算的时候不难发现target中为0经过乘法都是0了,因此最后只剩下正确类型的这个损失差距 最后公式可以演变成 - log Q(x)

(3)负对数似然(Negative Log-Likelihood)
对于单个样本,计算负对数似然:

其中是第i 个样本的交叉熵损失,但事实上,只在真实类别位置处概率为1,其余位置均为0,因此,可以进一步简化为

 其中,yi代表第i个样本在真实类别j=yi处的预测概率。其本质是利用真实概率分布筛选了预测概率分布在真实类别的概率值,并求负对数似然。

对于N个样本,则对这N个样本的交叉熵损失函数求和再求平均即可。

  1. 代码解析

cross_loss = torch.nn.CrossEntropyLoss(reduction='none')

#注意这里的预测输入是N*C,其中N是样本数,C是类别数,此时还不是概率,所以使用交叉熵损失函数的网络最后不需要softmax,损失函数自带。

input = torch.tensor([[4, 14, 19, 15],

                       [18, 6, 14, 7],

                       [18, 5, 3, 16]], dtype=torch.float)

#真实标签是每个样本的类别(1*N),api会自动生成one-hot概率分布(N*C)

    target = torch.tensor([0, 3, 2])

  #然后计算损失函数值

  loss = cross_loss(input, target)

    torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

参数

    参数说明:

1. weight:

  • CE 和 BCE 系列都有此参数,用于为每个类别的 loss 设置权重,常用于类别不均衡问题;
  • weight 必须是 float 类型的 1D tensor,长度和类别长度一致:weight = torch.from_numpy(np.array([0.6, 0.2, 0.2])).float().to(device)
  • 注意:weight 加起来未必一定要等于 1,类 c 对应的 weight 为 W_c = (N-N_c) / N,数目越多的类,weight 越小,weight 越大,此类得到的 loss 被放大;

2. ignore_index:

  • 其中 BCE 系列没有此参数,此参数用于指定忽略某些类别的 loss;

3. size_average:

  • 该参数指定 loss 是否在一个 batch 内平均,即是否除以 N,目前此参数已经被弃用

4. reduce:

  • 目前此参数已经被弃用

5. reduction:

  • 此参数在新版本中是为了取代 ”size_average“ 和 "reduce" 参数的;
  • mean (default):返回 N 个 loss 的平均值;
  • sum:返回 N 个 loss 的 sum;
  • None:直接返回一个 batch 中的 N 个 loss;

6. pos_weight:

  • 只有 BCEWithLogits 系列有次参数;
  • 与 weight 参数的区别是:WIP;

(5)nn.CrossEntropyLoss=nn.LogSoftmax(dim=1)+nn.NLLLoss()

(5)多维交叉熵

文本类数据通常是三维数据,预测通常是(batch_size,seq_length,num_vocab_size),而target是(batch_size,seq_length),此时需要预测的形状,通常使用permute操作成 (batch_size,num_vocab_size,seq_length)

参考资料

https://zhuanlan.zhihu.com/p/261059231

交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数 - 不愿透漏姓名的王建森 - 博客园

MSELoss — PyTorch 2.5 documentation

标签:loss,函数,交叉,样本,损失,均方,类别
From: https://blog.csdn.net/weixin_42251091/article/details/144635547

相关文章

  • 【Web】0基础学Web—函数、箭头函数、函数闭包、函数参数、js作用域、字符串
    0基础学Web—函数、箭头函数、函数闭包、函数参数、js作用域、字符串函数函数声明函数调用函数事件调用函数匿名函数立即执行函数箭头函数函数闭包函数参数js作用域字符串字符串创建字符串方法字符串拼接字符串截取去除字符串首尾空格遍历其他函数function函数名(......
  • MySQl常用函数解析
    1.LEAST函数:返回多个值中的最小值LEAST(value1,value2,...,valueN)2.GREATEST函数:返回多个值中的最大值GREATEST(value1,value2,...,valueN)字符串比较规则:从字符串的第一个字符开始,逐个字符进行比较,直到找到不同的字符为止。如果字符串的前几个字符相同......
  • 《操作系统真相还原》实验记录1.2——print.S打印函数
    一、print.S文件说明put_char函数(每次只打印一个字符)是各种打印函数的核心1.1功能说明put_char函数的处理流程备份寄存器现场;获取光标坐标值,光标坐标值是下一个可打印字符的位置;为了在光标处打印字符,需要读取光标坐标寄存器,获取光标坐标值。获取待打印的字符;......
  • 巧记斜边函数hypot
    hypot是一个数学函数,源于英文"hypotenuse(斜边)",hypot(a,b)返回直角边边长为a、b的直角三角形(right-angledtriangle)的斜边长度。该函数定义在<math.h>头文件中,其功能相当于sqrt(a*a+b*b)。函数原型:doublehypot( double x, double y );示例代码如下:#include<stdio.h>......
  • Python常用77个函数总结
    01、print()函数:打印字符串02、raw_input()函数:从用户键盘捕获字符03、len()函数:计算字符长度04、format(12.3654,‘6.2f’/‘0.3%’)函数:实现格式化输出05、type()函数:查询对象的类型06、int()函数、float()函数、str()函数等:类型的转化函数07、id()函数:获取对象的内......
  • 初始化列表和函数体内赋值的一个区别:
    先提出问题:看着乌漆嘛黑的代码,我的脑子在想运行结果为什么不是两“大”次复制构造函数,因为我认为传入参数这是第一步会调用复制构造函数,但是把参数赋值给类的实例对象的数据成员这也应该是复制构造函数啊。(这可是我花了几个小时才验证到的结果啊)这是两种不同的写法的不同输出......
  • CPP虚函数详解与实例
    CPP虚函数详解与实例在CMU_15445的Project3中大量使用了虚函数,抽象类的方法主要在Expression(表达式)以及Executor(Plan_Node的执行)中,在完成Part1的时候仅关注了功能的实现,还没有完全搞清楚为什么要使用虚函数以及抽象类,以及虚函数背后的原理,本次补充一下.......
  • C语言函数指针实用总结——高阶篇
     在C语言的江湖中,函数指针犹如一门深不可测的绝学,掌握它,你将能够游刃有余地处理各种复杂场景。今天,我们将深入探讨函数指针的高级用法,并通过一系列案例,让你领略其无限魅力。本文为高阶篇,适合已经有一定函数指针基础的读者。一、函数指针深度解析1.函数指针数组与指向函数......
  • 匿名函数和命名函数的区别?
    在前端开发中,匿名函数和命名函数是两种常见的函数定义方式,它们之间存在几个关键的区别。定义方式:命名函数:通过function关键字后跟函数名称来定义,如functionmyFunction(){...}。命名函数可以在其被定义之前的代码中被调用,这是由于JavaScript的变量提升机制。匿名函数:没有明......
  • 标准IO相关函数接口
    size_tfwrite(constvoid*ptr,size_tsize,size_tnmemb,           FILE*stream);功能:向文件中写入指定大小的nmemb个元素参数:    ptr:要写入数据的首地址    size:写入的每个元素的大小    nmemb:要写入的元素的个数......