首页 > 其他分享 >Pytorch常用内置损失函数合集

Pytorch常用内置损失函数合集

时间:2024-12-20 15:58:43浏览次数:5  
标签:loss 内置 函数 nn 示例 损失 Pytorch 合集 fn

       PyTorch 提供了多种内置的损失函数,适用于不同的任务和场景。这些损失函数通常已经优化并实现了常见的归约方式(如 meansum),并且可以直接用于训练模型。以下是常见的 PyTorch 内置损失函数及其适用场景:

1. 均方误差损失(Mean Squared Error, MSE)

  • 类名nn.MSELoss

  • 公式

    其中 N 是样本数量,yi是真实值,y^i是预测值;

  • 适用场景

    • 回归问题:当目标是预测连续值时,MSE 是最常见的损失函数。它衡量预测值与真实值之间的平方差,并对较大的误差施加更大的惩罚。
    • 时间序列预测:在时间序列预测任务中,MSE 也常用于衡量模型的预测性能。
  • 示例代码

    loss_fn = nn.MSELoss()

2. 二元交叉熵损失(Binary Cross-Entropy, BCE)

  • 类名nn.BCELoss

  • 公式

    其中 N是样本数量,yi 是真实标签(0 或 1),y^i 是预测的概率值(介于 0 和 1 之间)。

  • 适用场景

    • 二分类问题:当目标是将输入分为两个类别时,BCE 是常用的损失函数。它衡量预测概率与真实标签之间的差异。
    • 多标签分类:在多标签分类任务中,每个样本可以属于多个类别,BCE 可以用于每个标签的独立预测。
  • 注意事项

    • 预测值应为概率值(介于 0 和 1 之间)。如果你的模型输出是未经过激活函数的 logits,应该使用 nn.BCEWithLogitsLoss,它会自动应用 Sigmoid 激活函数。
  • 示例代码

loss_fn = nn.BCELoss()

 

3. 带逻辑斯蒂回归的二元交叉熵损失(BCE with Logits)

  • 类名nn.BCEWithLogitsLoss

  • 公式

    其中 σ 是 Sigmoid 函数,y^i 是模型输出的 logits(未经过 Sigmoid 激活的值)。

  • 适用场景

    • 二分类问题:类似于 nn.BCELoss,但它直接接受未经过 Sigmoid 激活的 logits,并在内部应用 Sigmoid 激活函数。这可以提高数值稳定性。
    • 多标签分类:同样适用于多标签分类任务。
  • 优点

    • 数值更稳定,因为 Sigmoid 和 BCE 的计算是在同一层完成的,避免了梯度消失或爆炸的问题。
  • 示例代码

    loss_fn = nn.BCEWithLogitsLoss()

     

4. 多分类交叉熵损失(Cross Entropy Loss)

  • 类名nn.CrossEntropyLoss

  • 公式

    其中 yi​ 是真实标签(整数表示类别),y^i是模型输出的 logits(未经过 Softmax 激活的值)。

  • 适用场景

    • 多分类问题:当目标是将输入分为多个类别时,Cross Entropy 是常用的损失函数。它结合了 Softmax 激活函数和负对数似然损失(NLL),适合处理多分类任务。
    • 图像分类:在图像分类任务中,Cross Entropy 是最常用的选择。
  • 注意事项

    • 预测值应为 logits(未经过 Softmax 激活的值)。nn.CrossEntropyLoss 会在内部自动应用 Softmax 激活函数。
    • 真实标签应为整数表示的类别索引,而不是 one-hot 编码。
  • 示例代码

loss_fn = nn.CrossEntropyLoss()

 

5. 负对数似然损失(Negative Log Likelihood, NLL)

  • 类名nn.NLLLoss

  • 公式

    其中 yi​ 是真实标签(整数表示类别),pi​ 是预测的概率分布(经过 Softmax 激活后的值)。

  • 适用场景

    • 多分类问题:类似于 nn.CrossEntropyLoss,但 nn.NLLLoss 需要输入已经是经过 Softmax 激活的概率分布。因此,通常与 nn.LogSoftmax 一起使用。
    • 自定义激活函数:如果你希望在损失函数之前应用自定义的激活函数(如温度缩放的 Softmax),可以使用 nn.NLLLoss
  • 示例代码

# 使用 LogSoftmax 和 NLLLoss
m = nn.LogSoftmax(dim=1)
loss_fn = nn.NLLLoss()
output = m(logits)
loss = loss_fn(output, target)

 

6. L1 损失(L1 Loss, Mean Absolute Error, MAE)

  • 类名nn.L1Loss

  • 公式

    其中 N 是样本数量,yi 是真实值,y^i 是预测值。

  • 适用场景

    • 回归问题:与 MSE 类似,L1 损失用于回归任务,但它对异常值(outliers)不太敏感,因为它使用绝对差而不是平方差。
    • 鲁棒性要求较高的任务:当你希望模型对异常值具有更好的鲁棒性时,L1 损失是一个不错的选择。
  • 示例代码

 

loss_fn = nn.L1Loss()

7. Smooth L1 损失(Huber Loss)

  • 类名nn.SmoothL1Loss

  • 公式

    其中 x=yi−y^i​ 是预测值与真实值之间的差异。

  • 适用场景

    • 回归问题:Smooth L1 损失结合了 MSE 和 L1 损失的优点。对于小误差,它使用平方差(类似于 MSE),而对于大误差,它使用绝对差(类似于 L1)。这使得它对异常值具有一定的鲁棒性,同时保持了 MSE 的平滑性。
    • 目标检测:在目标检测任务中,Smooth L1 损失常用于回归边界框的坐标。
  • 示例代码

loss_fn = nn.SmoothL1Loss()

8. Kullback-Leibler 散度损失(KL Divergence)

  • 类名nn.KLDivLoss

  • 公式

    其中 P 是真实分布,Q 是预测分布;

  • 适用场景

    • 分布匹配:当目标是使预测分布尽可能接近真实分布时,KL 散度是一个常用的损失函数。它衡量两个分布之间的差异。
    • 生成对抗网络(GANs):在 GAN 中,KL 散度常用于衡量生成分布与真实分布之间的差异。
    • 变分自编码器(VAEs):在 VAE 中,KL 散度用于正则化潜在变量的分布,使其接近标准正态分布。
  • 注意事项

    • 输入应为对数概率分布(即经过 nn.LogSoftmax 处理的值),而目标应为概率分布。
  • 示例代码

loss_fn = nn.KLDivLoss(reduction='batchmean')

9. Hinge 损失(Hinge Loss)

  • 类名nn.HingeEmbeddingLoss

  • 公式

    其中 y 是真实标签(1 或 -1),y^​ 是预测值。

  • 适用场景

    • 二分类问题:Hinge 损失常用于支持向量机(SVM)中,尤其是在二分类任务中。它鼓励模型将正类和负类之间的间隔最大化。
    • 度量学习:在度量学习任务中,Hinge 损失用于鼓励相似样本之间的距离最小化,而不相似样本之间的距离最大化。
  • 示例代码

loss_fn = nn.HingeEmbeddingLoss()

10. Cosine 相似度损失(Cosine Embedding Loss)

  • 类名nn.CosineEmbeddingLoss

  • 公式

    其中 x1​ 和 x2​ 是两个输入向量,y 是标签(1 表示相似,-1 表示不相似);

  • 适用场景

    • 度量学习:Cosine Embedding Loss 用于度量学习任务,鼓励相似样本之间的余弦相似度最大化,而不相似样本之间的余弦相似度最小化。
    • 对比学习:在对比学习任务中,Cosine Embedding Loss 用于拉近正样本对的距离,推远负样本对的距离。
  • 示例代码

loss_fn = nn.CosineEmbeddingLoss(margin=0.5)

 

 

 

标签:loss,内置,函数,nn,示例,损失,Pytorch,合集,fn
From: https://blog.csdn.net/CITY_OF_MO_GY/article/details/144611435

相关文章

  • 【运维发布】蓝绿部署滚动更新金丝雀发布授权策略敏感数据保护内置监控功能外部监控工
    【运维发布】蓝绿部署滚动更新金丝雀发布前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站部署策略(续)蓝绿部署(Blue-GreenDeployment)蓝绿部署是一种零停机的发布策略,它通过维护两个完全相同的生产环境来实现。一个环境......
  • 基于yolov8的小麦麦穗检测系统,支持图像、视频和摄像实时检测【pytorch框架、python源
       更多目标检测、图像分类识别、目标追踪等项目可看我主页其他文章功能演示:基于yolov8的小麦麦穗检测系统,支持图像、视频和摄像实时检测【pytorch框架、python源码】_哔哩哔哩_bilibili(一)简介基于yolov8的小麦麦穗检测系统在pytorch框架下实现的,这是一个完整的项目,包括......
  • 【图像分类】数据集合集!
    本文将为您介绍经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。1ImageProcessing-Python更新时间:2024-04-12访问地址: GitHub描述:该资源为作者在CSDN的撰写Python图像处理文章的支撑,主要是Python实现图像处理、图像识别、图像分类等算法代码实现。该系......
  • 【大数据】数据集合集!
    本文将为您介绍经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。1bigdata-growth更新时间:2024-11-14访问地址: GitHub描述:大数据知识仓库涉及到数据仓库建模、实时计算、大数据、数据中台、系统设计、Java、算法等。数据集网址:https://github.com/colla......
  • 【物联网】数据集合集!
    本文将为您介绍经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。1FastBee更新时间:2024-12-13访问地址: GitHub描述:FastBee开源物联网平台,简单易用,可用于搭建物联网平台以及二次开发和学习。适用于智能家居、智慧办公、智慧社区、农业监测、水利监测、工业......
  • USACO备考书籍合集
    USACO,全称UnitedStatesofAmericaComputingOlympiad,即美国计算机奥林匹克竞赛。以下是网上查到的关于USACO(美国计算机奥林匹克竞赛)的推荐书籍:一、国内推荐书籍有一种观点,冲击USACO铂金,无非就是“吃透”下面的前5本。这种论调是网上看得比较多的,但是老金也是刚刚查到,没看......
  • 深度学习笔记06-VGG16-Pytorch实现人脸识别
    本文通过调用预训练模型VGG16并进行模型微调,从而实现人脸识别。文章目录前言一、加载数据1.导入库2.导入数据3.定义transforms4.查看类别5.划分数据集6.加载数据二、调用VGG161.加载预训练模型2.模型微调三、训练模型1.训练函数2.测试函数3.动态学习率设......
  • 基于vgg16和efficientnet卷积神经网络的天气识别系统(pytorch框架) 图像识别与分类 前
    基于vgg16和efficientnet卷积神经网络的天气识别系统(pytorch框架)前端界面:flask+python,UI界面:pyqt5+python这是一个完整项目,包括代码,数据集,模型训练记录,前端界面,ui界面,各种指标图:包括准确率,精确率,召回率,F1值,损失曲线,准确率曲线等卷积模型采用vgg16模型或efficien......
  • 短期面试突击攻略大全!2025最全Java面试题目合集
     这两年的面试难度确实要比往年高处很多。很多小伙伴投递了上千份简历,只有几家公司约面试。排除个人简历的因素,这在往年都是不太常见的。大厂缩招,于是很多往年能进大厂的人只能去卷中小厂,搞得层层内卷。 比如往年能有一万个人能进大厂,今年大厂只招聘一千个,那另外九千个在往......
  • Apple礼品卡大合集
    1、什么是Apple礼品卡?苹果礼品卡分为两种:AppleStore礼品卡和APPStore礼品卡(也称为iTunes礼品卡)。苹果iTunes礼品卡代码是一串由16位数字和字母组成的字符,它可以被用来兑换iTunes礼品卡的金额。2、Apple礼品卡可以购买什么东西?AppleStore礼品卡‌:适用于苹果的线上商店和......