首页 > 其他分享 >torch.nn.CrossEntropyLoss

torch.nn.CrossEntropyLoss

时间:2022-12-02 17:07:31浏览次数:42  
标签:loss attr nn torch reduction cross CrossEntropyLoss entropy


文章目录

  • ​​交叉熵损失函数`torch.nn.CrossEntropyLoss`​​
  • ​​F.cross_entropy​​
  • ​​F.nll_loss​​

交叉熵损失函数​​torch.nn.CrossEntropyLoss​

  • weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size ​​C​​ 每个类别计算损失的权重
  • size_average (bool, optional): Deprecated (see :attr:​​reduction​​​). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there multiple elements per sample. If the field :attr:​​size_average​​​ is set to ​​False​​​, the losses are instead summed for each minibatch. Ignored when reduce is ​​False​​​. Default: ​​True​
  • ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. When ​​size_average​​​ is ​​True​​, the loss is averaged over non-ignored targets.
  • reduce (bool, optional): Deprecated (see :attr:​​reduction​​​). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:​​size_average​​​. When :attr:​​reduce​​​ is ​​False​​​, returns a loss per batch element instead and ignores :attr:​​size_average​​​. Default: ​​True​
  • reduction (string, optional): Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’.
  • ‘none’: no reduction will be applied,
  • ‘mean’: the sum of the output will be divided by the number of elements in the output
  • ‘sum’: the output will be summed.
  • Note: :attr:​​size_average​​​ and :attr:​​reduce​​​ are in the process of being deprecated, and in the meantime,specifying either of those two args will override :attr:​​reduction​​. Default: ‘mean’

简单来说,三个参数:​​weight​​​、​​ignore​​​、​​reduction​

  • ​weight​​调整每个类别的权重
  • ​ignore_index​​不计算损失的index,例如padding的index不计算损失
  • ​reduction​​控制loss的计算模式[none,mean,sum]
import torch.nn.functional as F
input = torch.randn(3,5)
label = torch.empty(3, dtype=torch.long).random_(5) # -> tensor([1, 3, 0])

res = F.cross_entropy(input, label)
>>> tensor(1.8942)
res_mean = F.cross_entropy(input, label, reduction='mean')
>>> tensor(1.8942)
res_sum = F.cross_entropy(input, label, reduction='sum')
>>> tensor(5.6826)
res_none = F.cross_entropy(input, label, reduction='none')
>>>tensor([1.3254, 2.9982, 1.3590])
res_ignore0 = F.cross_entropy(input, label, reduction='none', ignore_index=0)
>>>tensor([1.3254, 2.9982, 0.0000])

F.cross_entropy

​torch.nn.CrossEntropyLoss​​​调用了函数​​F.cross_entropy​​​,与tf中不同的是,​​F.cross_entropy​​​执行包含两部分​​log_softmax​​​和​​F.nll_loss​​​​log_softmax​​主要用于解决函数overflow和underflow,加快运算速度,提高数据稳定性。
softmax会进行指数操作,当输入比较大,会产生overflow;当输入为负数且绝对值也很大,会使得分子和分母很小,有可能四舍五入向下溢出。
在数学表达式是对softmax取对数,实际运算是通过下列式子:
torch.nn.CrossEntropyLoss_pytorch
其中,M为所有torch.nn.CrossEntropyLoss_ide_02中最大的值。

F.nll_loss

​F.nll_loss​​​表示​​The negative log likelihood loss.​​​log似然代价函数
​log_softmax与softmax的区别在哪里?pytorch的F.cross_entropy交叉熵函数


标签:loss,attr,nn,torch,reduction,cross,CrossEntropyLoss,entropy
From: https://blog.51cto.com/u_15899958/5907215

相关文章

  • Pytorch mask:上三角和下三角
    上三角triuPytorch上三角和下三角的调用与numpy是相同的。np.triu(np.ones((5,5)),k=0)#k控制对角线开始的位置Out[25]:array([[1.,1.,1.,1.,1.],[0.,1.,1......
  • 7.2 a single layer of GNN
    AsinglelayerofGNN1.IdeaofaGNNLayer:CompressasetofvectorsintoasinglevectorTwostepprocess:1.Message2.aggregation(1)messagecomputa......
  • ReactHook父组件调用子组件的方法,且子组件用了connect
    ReactHook父组件调用子组件的方法,且子组件用了connect子组件1、引入useImperativeHandle,forwardRef2、子组件由function改成let,接收prop和ref,并从props中结构出refI......
  • annotate和aggregate的区别
    一.基本区别aggregate:返回使用聚合函数后的字段和值。annotate:在原来模型字段的基础之上添加一个使用了聚合函数的字段二.使用方法classBook(models.Model):......
  • 如何优雅的关闭channel?
    一、channel使用存在的不方便地方1、在不改变channel自身状态的情况下,无法获知一个channnel是否关闭。2、关闭一个已经关闭的channel,会导致panic。因此,如果关闭channel的......
  • Vue3中Echart挂在全局报错问题 dataSample.js:104 Uncaught TypeError: Cannot read p
    原因Proxy应用到了整个ECharts实例上的问题,不太建议把整个ECharts实例这样的对象放到ref里,容易影响到实例底层的运行。可以使用shallowRef替代,这样Proxy不会应......
  • mysql innodb中的两类索引
    mysql的innodb中有两类索引,分别是Cluster形式的主键索引(PrimaryKey),另外一种则是和其他存储引擎(如MyISAM存储引擎)存放形式基本相同的普通B-Tree......
  • java解加密(AES/CBC)异常:java.lang.SecurityException: JCE cannot authenticate the
    原文链接:https://blog.csdn.net/weixin_43048843/article/details/109200673对接第三方厂商需求时,需要对数据AES256进行解密,由于java本身不支持,需要添加依赖。一、版本适......
  • FasterRcnn
    #FasterRCNN*原始版本*https://github.com/rbgirshick/py-faster-rcnn*论文*http://arxiv.org/abs/1506.01497*比较好的文章*https://zhuanlan.......
  • 矩池云 | GPU 分布式使用教程之 Pytorch
    GPU分布式使用教程之PytorchPytorch官方推荐使用DistributedDataParallel(DDP)模块来实现单机多卡和多机多卡分布式计算。DDP模块涉及了一些新概念,如网络(WorldSize......