目录
一、常用损失函数
1、CrossEntropyLoss(交叉熵损失)
1)原理
交叉熵损失是一种常用于分类问题的损失函数,它衡量的是模型输出的概率分布与真实标签分布之间的差异
在多分类问题中,模型会输出每个类别的预测概率。交叉熵损失通过计算真实标签对应类别的负对数概率来评估模型的性能。在实际应用中,nn.CrossEntropyLoss
内部会对logits(即未经softmax的原始输出)应用softmax函数,将其转换为概率分布,然后计算交叉熵。
例如:
假设有一个多类别分类任务,共有C个类别。对于每个样本,模型会输出一个包含C个元素的向量,其中每个元素表示该样本属于对应类别的概率。而真实标签是一个C维的向量,其中只有一个元素为1,其余元素均为0,表示样本的真实类别。
2)流程
首先,将模型输出的向量通过softmax函数进行归一化,将原始的概率值转换为概率分布。即对模型输出的每个元素进行指数运算,然后对所有元素求和,最后将每个元素除以总和,得到归一化后的概率分布。
然后,将归一化后的概率分布与真实标签进行比较,计算两者之间的差异。交叉熵损失函数的计算公式为: -sum(y * log(p)) ,其中y是真实标签的概率分布,p是模型输出的归一化后的概率分布。该公式表示真实标签的概率分布与模型输出的归一化后的概率分布之间的交叉熵。
最后,将每个样本的交叉熵损失值进行求和或平均,得到整个批次的损失值。
3)用法示例
import torch
import torch.nn as nn
# 假设有一个模型输出的logits和一个真实的标签
logits = torch.randn(10, 5, requires_grad=True) # 10个样本,5个类别
labels = torch.randint(0, 5, (10,)) # 真实标签,每个样本对应一个类别索引
# 创建CrossEntropyLoss实例
loss_fn = nn.CrossEntropyLoss()
# 计算损失
loss = loss_fn(logits, labels)
# 反向传播
loss.backward()
2、L1Loss(L1损失/平均绝对误差)
1)原理
L1损失,也称为平均绝对误差(MAE),计算的是预测值与真实值之差的绝对值的平均值。
L1损失对异常值(即远离平均值的点)的敏感度较低,因为它通过绝对值来度量误差,而绝对值函数在零点附近是线性的。
2)用法示例
loss_fn = nn.L1Loss()
predictions = torch.randn(3, 5, requires_grad=True) # 预测值
targets = torch.randn(3, 5) # 真实值
# 计算损失
loss = loss_fn(predictions, targets)
# 反向传播
loss.backward()
3、NLLLoss(负对数似然损失)
1)原理
负对数似然损失(NLLLoss)通常与log_softmax一起使用,用于多分类问题。它计算的是目标类别的负对数概率。
NLLLoss期望的输入是对数概率(即已经通过log_softmax处理过的输出),然后计算目标类别的负对数概率。
2)用法示例
# 假设已经计算了logits
logits = torch.randn(3, 5, requires_grad=True)
# 应用log_softmax获取对数概率(在PyTorch中,通常直接使用CrossEntropyLoss)
log_probs = torch.log_softmax(logits, dim=1)
# 创建NLLLoss实例
loss_fn = nn.NLLLoss()
# 真实标签
labels = torch.tensor([1, 0, 4], dtype=torch.long)
# 计算损失
loss = loss_fn(log_probs, labels)
# 反向传播
loss.backward()
在实际应用中,直接使用CrossEntropyLoss
更为常见,因为它内部集成了softmax和NLLLoss的计算。
4、 MSELoss(均方误差损失)
1)定义
均方误差损失(MSE)计算的是预测值与真实值之差的平方的平均值。
MSE通过平方误差来放大较大的误差,从而给予模型更大的惩罚。它是回归问题中最常用的损失函数之一。
2)用法示例
loss_fn = nn.MSELoss()
predictions = torch.randn(3, 5, requires_grad=True) # 预测值
targets = torch.randn(3, 5) # 真实值
# 计算损失
loss = loss_fn(predictions, targets)
# 反向传播
loss.backward()
5.BCELoss(二元交叉熵损失)
1)定义
二元交叉熵损失(BCE)用于二分类问题,计算的是预测概率与真实标签(0或1)之间的交叉熵。
BCE通过计算真实标签对应类别的负对数概率来评估模型的性能。它适用于输出概率的模型,但并不要求输入必须经过sigmoid函数(尽管在实践中很常见)。
2)用法示例
loss_fn = nn.BCELoss()
# 假设预测值已经通过sigmoid函数(虽然不是必需的)
predictions = torch.sigmoid(torch.randn(3, requires_grad=True))
# 真实标签
targets = torch.empty(3).random_(2).float() # 生成0或1的随机值
# 计算损失
loss = loss_fn(predictions, targets)
# 反向传播
loss.backward()
二、总结常用损失函数
1、nn.CrossEntropyLoss:交叉熵损失函数
主要用于多分类问题。它将模型的输出(logits)与真实标签进行比较,并计算损失。
2、nn.MSELoss:均方误差损失函数
用于回归问题。它计算模型输出与真实标签之间的差异的平方,并返回平均值。
3、nn.L1Loss:平均绝对误差损失函数
也称为L1损失。类似于MSELoss,但是它计算模型输出与真实标签之间的差异的绝对值,并返回平均值。
4、nn.BCELoss:二元交叉熵损失函数
用于二分类问题。它计算二分类问题中的模型输出与真实标签之间的差异,并返回损失。
5、nn.NLLLoss:负对数似然损失函数
主要用于多分类问题。它首先应用log_softmax函数(log_softmax(x) = log(softmax(x)))将模型输出转化为对数概率,然后计算模型输出与真实标签之间的差异。
标签:loss,函数,nn,标签,torch,损失,用法,PyTorch,解析 From: https://blog.csdn.net/qq_64603703/article/details/142343950