首页 > 其他分享 >《深度学习》PyTorch 常用损失函数原理、用法解析

《深度学习》PyTorch 常用损失函数原理、用法解析

时间:2024-09-18 21:49:20浏览次数:14  
标签:loss 函数 nn 标签 torch 损失 用法 PyTorch 解析

目录

一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

        2)流程

        3)用法示例

2、L1Loss(L1损失/平均绝对误差)

        1)原理

        2)用法示例

3、NLLLoss(负对数似然损失)

        1)原理

        2)用法示例

4、 MSELoss(均方误差损失)

        1)定义

        2)用法示例

5. BCELoss(二元交叉熵损失)

        1)定义

        2)用法示例

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

        2、nn.MSELoss:均方误差损失函数

        3、nn.L1Loss:平均绝对误差损失函数

        4、nn.BCELoss:二元交叉熵损失函数

        5、nn.NLLLoss:负对数似然损失函数


一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

                交叉熵损失是一种常用于分类问题的损失函数,它衡量的是模型输出的概率分布与真实标签分布之间的差异

                在多分类问题中,模型会输出每个类别的预测概率。交叉熵损失通过计算真实标签对应类别的负对数概率评估模型的性能。在实际应用中,nn.CrossEntropyLoss内部会对logits(即未经softmax的原始输出)应用softmax函数,将其转换为概率分布,然后计算交叉熵。

                例如:

                        假设有一个多类别分类任务,共有C个类别。对于每个样本,模型会输出一个包含C个元素的向量,其中每个元素表示该样本属于对应类别的概率。而真实标签是一个C维的向量,其中只有一个元素为1,其余元素均为0,表示样本的真实类别。

        2)流程

                首先,将模型输出的向量通过softmax函数进行归一化,将原始的概率值转换为概率分布。即对模型输出的每个元素进行指数运算,然后对所有元素求和,最后将每个元素除以总和,得到归一化后的概率分布。

                然后,将归一化后的概率分布与真实标签进行比较,计算两者之间的差异。交叉熵损失函数的计算公式为: -sum(y * log(p))  ,其中y是真实标签的概率分布,p是模型输出的归一化后的概率分布。该公式表示真实标签的概率分布与模型输出的归一化后的概率分布之间的交叉熵。

                最后,将每个样本的交叉熵损失值进行求和或平均,得到整个批次的损失值。

       

        3)用法示例
import torch  
import torch.nn as nn  
  
# 假设有一个模型输出的logits和一个真实的标签  
logits = torch.randn(10, 5, requires_grad=True)  # 10个样本,5个类别  
labels = torch.randint(0, 5, (10,))  # 真实标签,每个样本对应一个类别索引  
  
# 创建CrossEntropyLoss实例  
loss_fn = nn.CrossEntropyLoss()  
  
# 计算损失  
loss = loss_fn(logits, labels)  
  
# 反向传播  
loss.backward()

2、L1Loss(L1损失/平均绝对误差)

        1)原理

                L1损失,也称为平均绝对误差(MAE),计算的是预测值与真实值之差绝对值平均值

                L1损失对异常值(即远离平均值的点)的敏感度较低,因为它通过绝对值来度量误差,而绝对值函数在零点附近是线性的。

       

        2)用法示例
loss_fn = nn.L1Loss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

3、NLLLoss(负对数似然损失)

        1)原理

                负对数似然损失(NLLLoss)通常与log_softmax一起使用,用于多分类问题。它计算的是目标类别负对数概率

                NLLLoss期望的输入是对数概率(即已经通过log_softmax处理过的输出),然后计算目标类别的负对数概率。

        2)用法示例
# 假设已经计算了logits  
logits = torch.randn(3, 5, requires_grad=True)  
  
# 应用log_softmax获取对数概率(在PyTorch中,通常直接使用CrossEntropyLoss)  
log_probs = torch.log_softmax(logits, dim=1)  
  
# 创建NLLLoss实例  
loss_fn = nn.NLLLoss()  
  
# 真实标签  
labels = torch.tensor([1, 0, 4], dtype=torch.long)  
  
# 计算损失  
loss = loss_fn(log_probs, labels)  
  
# 反向传播  
loss.backward()

                在实际应用中,直接使用CrossEntropyLoss更为常见,因为它内部集成了softmax和NLLLoss的计算。

4、 MSELoss(均方误差损失)

        1)定义

                均方误差损失(MSE)计算的是预测值与真实值之差的平方的平均值

                MSE通过平方误差来放大较大的误差,从而给予模型更大的惩罚。它是回归问题中最常用的损失函数之一。

        2)用法示例
loss_fn = nn.MSELoss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

5.BCELoss(二元交叉熵损失)

        1)定义

                二元交叉熵损失(BCE)用于二分类问题,计算的是预测概率与真实标签(0或1)之间的交叉熵

                BCE通过计算真实标签对应类别的负对数概率来评估模型的性能。它适用于输出概率的模型,但并不要求输入必须经过sigmoid函数(尽管在实践中很常见)。

        2)用法示例
loss_fn = nn.BCELoss()  
  
# 假设预测值已经通过sigmoid函数(虽然不是必需的)  
predictions = torch.sigmoid(torch.randn(3, requires_grad=True))  
  
# 真实标签  
targets = torch.empty(3).random_(2).float()  # 生成0或1的随机值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

                主要用于多分类问题。它将模型的输出(logits)与真实标签进行比较,并计算损失。

        2、nn.MSELoss:均方误差损失函数

                用于回归问题。它计算模型输出与真实标签之间的差异的平方,并返回平均值。

        3、nn.L1Loss:平均绝对误差损失函数

                也称为L1损失。类似于MSELoss,但是它计算模型输出与真实标签之间的差异的绝对值,并返回平均值。

        4、nn.BCELoss:二元交叉熵损失函数

                用于二分类问题。它计算二分类问题中的模型输出与真实标签之间的差异,并返回损失。

        5、nn.NLLLoss:负对数似然损失函数

                主要用于多分类问题。它首先应用log_softmax函数(log_softmax(x) = log(softmax(x)))将模型输出转化为对数概率,然后计算模型输出与真实标签之间的差异。

标签:loss,函数,nn,标签,torch,损失,用法,PyTorch,解析
From: https://blog.csdn.net/qq_64603703/article/details/142343950

相关文章

  • Unity UI控件用法汇总
    利用LoopListView实现Banner循环列表,且默认中间节点为默认节点:  1.给ScrollRect节点添加LoopListView组件,并勾选ItemSnapEnable为true。  2.通过LoopListView.InitListView初始化时,totalCount需要传-1.  3.OnGetItemByIndex的回调参数index以(Int32.MinValue,Int32.MaxVa......
  • 接收网络包的过程——从硬件网卡解析到IP层
    当一些网络包到来触发了中断,内核处理完这些网络包之后,我们可以先进入主动轮询poll网卡的方式,主动去接收到来的网络包。如果一直有,就一直处理,等处理告一段落,就返回干其他的事情。当再有下一批网络包到来的时候,再中断,再轮询poll。这样就会大大减少中断的数量,提升网络处理的效率,这......
  • MySQL数据库select语句详细用法三(子查询及其select练习)
    SELECT*FROMstudent2WHEREage> (SELECTageFROMstudent2WHERENAME='欧阳丹丹')首先解释一下括号中的代码,意思是在查询student2中的name为欧阳丹丹的人的名字,然后解释一下整个语句的意思:在括号中查询出来的字段中再次进行查询在student2中age大于name为欧阳丹丹的......
  • 中国企业数据资产入表情况跟踪全文解析
    随着数字经济的蓬勃发展,数据已成为企业的重要资产。2023年8月,中国财政部发布的《企业数据资源相关会计处理暂行规定》标志着数据资源正式纳入会计核算体系,为企业数据资产的管理和运用提供了政策支持。本文将详细分析2024年第一季度中国企业数据资产入表的现状、挑战与未来趋势。一......
  • vulnhub(9):sickos1.2(深挖靶机的各个细节、文件管道反弹shell详解、base64编码反弹shell
    端口nmap主机发现nmap-sn192.168.148.0/24​Nmapscanreportfor192.168.148.131Hostisup(0.00020slatency).​131是新出现的机器,他就是靶机nmap端口扫描nmap-Pn192.168.148.131-p---min-rate10000-oAnmap/scan扫描开放端口保存到nmap/scan下​......
  • 菜鸟笔记之PWN入门(1.1.0)ELF 文件格式和程序段解析(简版)
    ELF(ExecutableandLinkableFormat):是一种用于可执行文件、目标文件和库的文件格式,类似于Windows下的PE文件格式。ELF主要包括三种类型的文件:可重定位文件(relocatable):编译器和汇编器产生的 .o 文件,由 Linker 处理。可执行文件(executable): Linker ......
  • Javaweb之SpringBootWeb案例之修改员工的修改回显的详细解析
     3.修改员工需求:修改员工信息编辑在进行修改员工信息的时候,我们首先先要根据员工的ID查询员工的信息用于页面回显展示,然后用户修改员工数据之后,点击保存按钮,就可以将修改的数据提交到服务端,保存到数据库。具体操作为:根据ID查询员工信息保存修改的员工信息3.1查询回显3.1.1接口......
  • 想成为Admineloper?Salesforce全新职业解析,机会就在眼前!
    每个Salesforce管理员在思考自己的职业生涯时,可能都会想到:下一步是往架构师,或者开发,还是咨询的方向发展。无论哪种职业规划,都需要培养新的技能。由于角色职责、团队、客户需求等的变化,许多管理员在晋升之前就开始培养这种新技能。管理员可能需要学习Apex或利用现有知识,转向传统......
  • Metasploit Framework (MSF) 使用指南 - 第一篇:介绍与基础用法
    引言MetasploitFramework(MSF)是一款功能强大的开源安全漏洞检测工具,被广泛应用于渗透测试中。它内置了数千个已知的软件漏洞,并持续更新以应对新兴的安全威胁。MSF不仅限于漏洞利用,还包括信息收集、漏洞探测和后渗透攻击等多个环节,因此被安全社区誉为“可以黑掉整个宇宙”的工具。......
  • CISP-PTE综合靶场解析,msf综合利用MS14-058【附靶场环境】
    CISP-PTE综合靶场解析,msf综合利用MS14-058【附靶场环境】前言需要靶场的朋友们,可以在后台私信【pte靶场】,有网安学习群,可以关注后在菜单栏选择学习群加入即可信息收集题目要求:给定一个ip,找到3个KEY。nmap扫描,为了节省时间,这里改了端口,就不使用-p-全端口扫描了,源靶机在20000+......