首页 > 其他分享 >Pytorch的cross_entropy为什么等于log_softmax加nll_loss

Pytorch的cross_entropy为什么等于log_softmax加nll_loss

时间:2023-01-12 20:23:34浏览次数:49  
标签:loss tensor 0.1 torch cross 0.7 log

首先我们要知道nll_loss是怎么算的,看下面的代码

label1 = torch.tensor([0, 3])
pred1 = torch.tensor([
    [0.2, 0.7, 0.8, 0.1],
    [0.1, 0.3, 0.5, 0.7]
    ])
loss = F.nll_loss(pred1, label1)
print(loss)  # 输出 tensor(-0.4500)

如何理解上面的代码?首先明确这是一个分类任务,总共有4个类,上面的代码计算了两条数据,可以认为是bachSize = 2。

  预测为第0类的概率 预测为第1类的概率 预测为第2类的概率 预测为第3类的概率
第1条数据 0.2 0.7 0.8 0.1
第2条数据 0.1 0.3 0.5 0.7

再具体点,每一条数据可以认为是一张图片,每一个类可以认为是该图片是鸡、鸭、鱼、鹅四种动物的概率。label1 = [0, 3]表示两条数据分别属于第0类和第3类,相当于下面的情况。

第1条数据 第0类 第1类 第2类 第3类
预测概率 0.2 0.7 0.8 0.1
实际概率 1 0 0 0

 

第2条数据 第0类 第1类 第2类 第3类
预测概率 0.1 0.3 0.5 0.7
实际概率 0 0 0 1

现在对问题的定义应该比价清楚了,接下来是nll_loss怎么算的,用公式不太好写,这里就用文字描述了:真实类别的预测概率的平均值乘负一。两条数据的真实标签分别是第0类和第3类,相应的预测概率分别为0.2和0.7,平均值为0.45,再乘负一,得0.45,与程序输出情况一致。其中求平均值是因为程序默认reduction='mean'

可以看出来nll_loss只能求每条数据只属于一个类别的情况(我目前理解是这样的),不能出现一条数据既属于第0类,又属于第1类。

同样适用上面的数据,我们计算cross_entropy

label3 = torch.tensor([
    [1, 0, 0, 0],
    [0, 0, 0, 1]
], dtype = torch.float32)
pred3 = torch.tensor([
    [0.2, 0.7, 0.8, 0.1],
    [0.1, 0.3, 0.5, 0.7]
    ])
loss = F.cross_entropy(pred3, label3)
print(loss)  # 输出 tensor(1.3965)

上面的label3代表数据是每个类别的真实概率是多少,跟上面的两个表格一样。label3 也可以用indices(也就是指明属于哪个类别),即 label3 = torch.tensor([0,3]),两者是等价的。

下面探讨如何用log_softmax和nll_loss组合出cross_entropy,代码如下:

label2 = torch.tensor([0, 3])
pred2 = torch.tensor([
    [0.2, 0.7, 0.8, 0.1],
    [0.1, 0.3, 0.5, 0.7]
    ])
pred2 = F.log_softmax(pred2, dim = 1)    # dim = 1是横着四个元素和为1, dim = 0是竖着两个元素和为1
loss = F.nll_loss(pred2, label2)
print(loss)  # 输出 tensor(1.3965)

这里比第一次的代码多了一句 pred2 = F.log_softmax(pred2, dim = 1)。log_softmax的意思是先softmax,再log(实际是ln,以e为底的log)。log用来保证最终结果为正(softmax压缩到区间[0,1])

为了更深刻的理解,我们接下来手算一下。

原始数据1 0.2 0.7 0.8 0.1
softmax后 0.1860 0.3067 0.3390 0.1683
log后(ln) -1.682 -1.1818 -1.0817 -1.7820

 

原始数据2 0.1 0.3 0.5 0.7
softmax后 0.1807 0.2207 0.2695 0.3292
log后(ln) -1.7109 -1.5109 -1.3111 -1.1110

按照nll_loss的计算方法:真实类别的预测概率的平均值乘负一:-1 * (-1.682 + -1.1110) / 2 = 1.3965,与程序输出结果一致。

如果cross_entropy时,每条数据可以同时属于多个类别,又该如何计算呢?如下面的代码,第一条数据同时属于0,1类别,第二条数据同时属于2,3类别。

label4 = torch.tensor([
    [1, 1, 0, 0],
    [0, 0, 1, 1]
], dtype = torch.float32)
pred4 = torch.tensor([
    [0.2, 0.7, 0.8, 0.1],
    [0.1, 0.3, 0.5, 0.7]
    ])
loss = F.cross_entropy(pred4, label4)
print(loss)  # 输出 tensor(2.6430)

明天再写

标签:loss,tensor,0.1,torch,cross,0.7,log
From: https://www.cnblogs.com/roadwide/p/17047786.html

相关文章

  • m在ISE平台下使用verilog开发基于FPGA的GMSK调制器
    1.算法描述       高斯最小频移键控(GaussianFilteredMinimumShiftKeying),这是GSM系统采用的调制方式。数字调制解调技术是数字蜂窝移动通信系统空中接口的重要......
  • log4j2.xml配置自定义参数 日志变量打印,比如全局traceId
    1、在拦截器中设置MDC的变量packagecom.sleep.demo.intercepter;importlombok.extern.slf4j.Slf4j;importorg.apache.commons.lang3.StringUtils;importorg.sl......
  • 打印日志 log4j
    前言**测试程序是经常需要打印后台执行日志,判断问题**<!--https://mvnrepository.com/artifact/log4j/log4j--><dependency><groupId>log4j</groupId><ar......
  • macOS 13 login items notifications close bug All In One
    macOS13loginitemsnotificationsclosebugAllInOne关闭后,不生效bugs太多次的重复通知出现,遮挡屏幕,影响正常使用无法批量一键清除通知,要一个一个的滑动......
  • .Net Core Logging模块源码阅读
    .NetCoreLogging模块源码阅读前言在Asp.NetCoreWebapi项目中经常会用到ILogger,于是在空闲的时候就clone了一下官方的源码库下来研究,这里记录一下。官方库地址在:h......
  • graylog docker-compose 安装yaml
    graylog是一款日志工具docker-compose部署version:'3'services:#MongoDB:https://hub.docker.com/_/mongo/mongo:image:mongo:5.0.13networks:......
  • 03-Verilog语法
    Verilog语法1Register组合逻辑-->寄存器-->组合逻辑-->寄存器Register是一个变量,用于存储值,并不代表一个真正的硬件DFF。regA,C;//assignmentsarealwaysdonei......
  • glog
    //#defineGOOGLE_STRIP_LOG1#include<glog/logging.h>#include<iostream>intmain(intargc,char**argv){FLAGS_log_dir="./log_dir";FLAGS_alsologt......
  • log4j2.xml
    <?xmlversion="1.0"encoding="utf-8"?><configuration><properties><!--文件输出格式--><propertyname="PATTERN">%d{yyyy-MM-ddHH:mm:ss......
  • nn.MarginRankingLoss介绍
    nn.MarginRankingLoss复现论文代码中,它使用了MarginRankingLoss()函数,以下是我百度的内容:排序损失函数对于包含\(\mathbf{N}\)个样本的batch数据\(D(x_1,x_2,y)\),\(x......