首页 > 其他分享 >PyTorch中F.cross_entropy()函数

PyTorch中F.cross_entropy()函数

时间:2022-11-10 10:36:59浏览次数:42  
标签:01 log pred torch cross PyTorch entropy soft


对PyTorch中F.cross_entropy()的理解

PyTorch提供了求交叉熵的两个常用函数:

一个是F.cross_entropy(),

另一个是F.nll_entropy(),

是对F.cross_entropy(input, target)中参数target讲解如下。
一、交叉熵的公式及计算步骤
1、交叉熵的公式:

H(p,q)=−i∑P(i)logQ(i)

其中 P P为真实值, Q Q为预测值。
2、计算交叉熵的步骤:
1)步骤说明:

①将predict_scores进行softmax运算,将运算结果记为pred_scores_soft;
②将pred_scores_soft进行log运算,将运算结果记为pred_scores_soft_log;
③将pred_scores_soft_log与真实值进行计算处理。
思路即:
                                                  scores→softmax→log→compute
2)举一个例子对计算进行说明:

P 1 = [ 1 0 0 0 0 ]

Q 1 = [ 0.4 0.3 0.05 0.05 0.2 ]

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) = − ( 1 ∗ l o g 0.4 + 0 ∗ l o g 0.3 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0.2 ) = − l o g 0.4 ≈ 0.916
如果
Q 2 = [ 0.98 0.01 0 0 0.01 ]

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) = − ( 1 ∗ l o g 0.98 + 0 ∗ l o g 0.01 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0 + 0 ∗ l o g 0.01 ) = − l o g 0.98 ≈ 0.02

由 H ( p , q )的计算结果和直观地观察 Q1​和 Q2​与 P1​的相似度,均可看出  Q2​比 Q1​更近似于 P 1
二、官方文档的说明

在PyTorch的官方中文文档中F.cross_entropy()的记录如下:

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=True)

该函数使用了 log_softmax 和 nll_loss,详细请看CrossEntropyLoss

常用参数:

参数名

shape


input

(N,C)

C是类别的个数

target

N

0 <= targets[i] <= C-1

三、自己的理解

在官方文档说明中,对于target参数的说明为,torch.shape为torch.Size([N]),0 <= targets[i] <= C-1。
网络计算输出并送入函数中的input的torch.shape为torch.Size([N,C]),它的torch.shape并不会因为softmax和log的操作而改变,但是target的torch.shape为torch.Size([N]),是一个标量而不是一个矩阵,那么如何按照上面的例子中的运算方法进行交叉熵的计算?

例如:

import torch
import torch.nn.functional as Fpred_score = torch.tensor([[13., 3., 2., 5., 1.],
[1., 8., 20., 2., 3.],
[1., 14., 3., 5., 3.]])
print(pred_score)
pred_score_soft = F.softmax(pred_score, dim=1)
print(pred_score_soft)
pred_score_soft_log = pred_score_soft.log()
print(pred_score_soft_log)

它的结果为:

tensor([[13.,  3.,  2.,  5.,  1.],
[ 1., 8., 20., 2., 3.],
[ 1., 14., 3., 5., 3.]])
tensor([[9.9960e-01, 4.5382e-05, 1.6695e-05, 3.3533e-04, 6.1417e-06],
[5.6028e-09, 6.1442e-06, 9.9999e-01, 1.5230e-08, 4.1399e-08],
[2.2600e-06, 9.9984e-01, 1.6699e-05, 1.2339e-04, 1.6699e-05]])
tensor([[-4.0366e-04, -1.0000e+01, -1.1000e+01, -8.0004e+00, -1.2000e+01],
[-1.9000e+01, -1.2000e+01, -6.1989e-06, -1.8000e+01, -1.7000e+01],
[-1.3000e+01, -1.5904e-04, -1.1000e+01, -9.0002e+00, -1.1000e+01]])

如何与一个标量target进行计算?
四、分析

F.Cross_entropy(input, target)函数中包含了softmax和log的操作,即网络计算送入的input参数不需要进行这两个操作。

例如在分类问题中,input表示为一个torch.Size([N, C])的矩阵,其中,N为样本的个数,C是类别的个数,input[i][j]可以理解为第 i 样本的类别为 jj的Scores,Scores值越大,类别为 j 的可能性越高,就像在代码块中所体现的那样。

同时,一般我们将分类问题的结果作为lable表示时使用one-hot embedding,例如在手写数字识别的分类问题中,数字0的表示为 [ 1 0 0 0 0 0 0 0 0 0 ]
数字3的表示为 [ 0 0 0 1 0 0 0 0 0 0 ]
在手写数字识别的问题中,我们计算 l o s s loss loss的方法为 l o s s = ( y − y ^ ) 2 ,即求  y的embedding的矩阵减去pred_probability矩阵的结果矩阵的范数。

但是在这里,交叉熵的计算公式为

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i )

其中 P 为真实值概率矩阵, Q为预测值概率矩阵。

那么如果 P使用one-hot embedding的话,只有在 i 为正确分类时 P ( i ) 才等于  1,否则, P ( i ) 等于0。
例如在手写数字识别中,数字3的one-hot表示为 [ 0 0 0 1 0 0 0 0 0 0 ]
对于交叉熵来说, H ( p , q ) = − ∑ i P ( i ) l o g Q ( i ) = − P ( 3 ) l o g Q ( 3 ) = − l o g Q ( 3 )

发现 H ( p , q ) 的计算不依赖于 P矩阵,而仅仅与 P的真实类别的index有关
五、总结

所以,我的理解是,在one-hot编码的前提下,在pytorch代码中target不需要以one-hot形式表示,而是直接用scalar,scalar的值则是真实类别的index。所以交叉熵的公式可表示为:
H ( p , q ) = − ∑ i P ( i ) l o g Q ( i ) = − P ( m ) l o g Q ( m ) = − l o g Q ( m )
其中, m m m表示真实类别。

标签:01,log,pred,torch,cross,PyTorch,entropy,soft
From: https://blog.51cto.com/u_13206712/5839950

相关文章

  • pytorch张量索引
    一、pytorch返回最值索引1官方文档资料1.1torch.argmax()介绍 返回最大值的索引下标函数:torch.argmax(input,dim,keepdim=False)→LongTensor返回值:Retur......
  • pytorch tensor 张量常用方法介绍
    1. view()函数PyTorch 中的view()函数相当于numpy中的resize()函数,都是用来重构(或者调整)张量维度的,用法稍有不同。>>>importtorch>>>re=torch.tensor([1,......
  • pytorch TensorDataset和DataLoader区别
    TensorDatasetTensorDataset可以用来对tensor进行打包,就好像python中的zip功能。该类通过每一个tensor的第一个维度进行索引。因此,该类中的tensor第一维度必须......
  • pytorch入门
    初衷:看不懂论文开源代码参考:B站小土堆(土堆yyds~)   PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili 1.环境配置参考:(39条消息)win10......
  • 一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】
    ????声明:作为全网AI领域干货最多的博主之一,❤️不负光阴不负卿❤️????深度学习:#超分重建、一文读懂????超分重建经典网络SRGAN详尽教程????最近更新:2022年2月28......
  • 使用PyTorch实现简单的AlphaZero的算法(1):背景和介绍
    在本文中,我们将在PyTorch中为ChainReaction[2]游戏从头开始实现DeepMind的AlphaZero[1]。为了使AlphaZero的学习过程更有效,我们还将使用一个相对较新的改进,称为“Playout......
  • Transfer-Meta Framework for Cross-domain Recommendation to Cold-Start Users阅读
    动机本文是2021年SIGIR上的一篇论文。本文主要针对的是冷启动问题中的跨域推荐问题,目前常用的方法是EMCDR,但是这个方法很大局限性,它仅在重叠的用户上学习,这样学到的模型会......
  • pandas df分段(cut)后交叉(crosstab)数据标签缺失的补充
    数值数据分类后交叉,但是数据量少,或者划分标准不科学导致分类的类别有缺失,交叉后会丧失类别,数据不齐整importnumpyasnpimportpandasaspddf=pd.DataFrame(n......
  • PyTorch实现非极大值抑制(NMS)
    NMS即nonmaximumsuppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值。在最近几年常见的物体检测算法(包括rcnn、sppnet、fast-rcnn、faster-rcnn......
  • Pytorch中模型调用
    注意:RNN、LSTM的batch_first参数,对于不同的网络层,输入的维度虽然不同,但是通常输入的第一个维度都是batch_size,比如torch.nn.Linear的输入(batch_size,in_features),torch.nn......