首页 > 其他分享 >nn.MarginRankingLoss介绍

nn.MarginRankingLoss介绍

时间:2023-01-11 12:34:09浏览次数:50  
标签:loss input1 target nn 介绍 reduction MarginRankingLoss margin

nn.MarginRankingLoss

复现论文代码中,它使用了MarginRankingLoss()函数,以下是我百度的内容:

排序损失函数

对于包含\(\mathbf{N}\)个样本的batch数据 \(D(x_1,x_2,y)\), \(x_1\),\(x_2\)是给定的待排序的两个输入,\(y\)代表真实的标签,属于{ 1 , − 1 } 。当Y = 1 是,\(x_1\)应该排在\(x_2\)前,Y = − 1 是,\(x_1\)应该排在\(x_2\)之后。

第n个样本对应的loss计算如下:

\[l_n = \max(0,-y*(x_1-x_2)+margin) \]

若\(x_1\),\(x_2\)排序正确且\(-y*(x_1-x_2)>margin\),则loss为0

class MarginRankingLoss(_Loss):
    __constants__ = ['margin', 'reduction']
    def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'):
        super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
        self.margin = margin
    def forward(self, input1, input2, target):
        return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)

pytorch中通过torch.nn.MarginRankingLoss类实现,也可以直接调用F.margin_ranking_loss 函数,代码中的size_average与reduce已经弃用。reduction有三种取值mean, sum, none,对应不同的返回ℓ ( x , y )。 默认为mean,对应于上述loss的计算

\[L=\{l_1,\dots, l_N\} \]

\[\ell(x, y)= \begin{cases}\mathrm{L}, & \text { if reduction = 'none' } \\ \frac{1}{N} \sum_{i=1}^{N} l_{i}, & \text { if reduction = 'mean' } \\ \sum_{i=1}^{N} l_{i} & \text { if reduction = 'sum' }\end{cases} \]

margin默认取0

例子:

import torch
import torch.nn.functional as F
import torch.nn as nn
import math

def validate_MarginRankingLoss(input1, input2, target, margin):
    val = 0
    for x1, x2, y in zip(input1, input2, target):
        loss_val = max(0, -y * (x1 - x2) + margin)
        val += loss_val
    return val / input1.nelement()

torch.manual_seed(10)
margin = 0
loss = nn.MarginRankingLoss()
input1 = torch.randn([3], requires_grad=True)
input2 = torch.randn([3], requires_grad=True)
target = torch.tensor([1, -1, -1])
print(target)
output = loss(input1, input2, target)
print(output.item())

output = validate_MarginRankingLoss(input1, input2, target, margin)
print(output.item())

loss = nn.MarginRankingLoss(reduction="none")
output = loss(input1, input2, target)
print(output)

'''
tensor([ 1, -1, -1])
0.015400052070617676
0.015400052070617676
tensor([0.0000, 0.0000, 0.0462], grad_fn=<ClampMinBackward>)
'''

标签:loss,input1,target,nn,介绍,reduction,MarginRankingLoss,margin
From: https://www.cnblogs.com/jev-0987/p/17043369.html

相关文章

  • HAL库教程1:STM32Cube的介绍
      使用STM32HAL库已经有了一段时间,觉得相比于标准库,好用了不少。加上STM32CubeMX图形化配置工具的加持,个人认为可以极大提升开发效率。其实关于HAL库的教程已经很多了,关于......
  • VS2022 17.1.6在windows10下打开winform设计器报timed out while connecting to named
    .net6.0的项目,vs202217.1.6在windows10下打开winform设计器报timedoutwhileconnectingtonamedpipe错误,同样的项目在windows7却可以打开winform设计器,很奇怪。N多......
  • C#中内联函数的用法介绍
    函数调用在执行时,首先要在栈中为形参和局部变量分配存储空间,然后还要将实参的值复制给形参,接下来还要将函数的返回地址(该地址指明了函数执行结束后,程序应该回到哪里继续执......
  • Spring 中的Advice类型介绍
    Spring中的Advice类型介绍翻译原文链接IntroductiontoAdviceTypesinSpring1.概述在本文中,我们将讨论可以在Spring中创建的不同类型的AOP通知。Inthisa......
  • TiDB 底层存储结构 LSM 树原理介绍
    作者:京东给物流刘家存随着数据量的增大,传统关系型数据库越来越不能满足对于海量数据存储的需求。对于分布式关系型数据库,我们了解其底层存储结构是非常重要的。本文将介绍......
  • TiDB 底层存储结构 LSM 树原理介绍
    作者:京东给物流刘家存随着数据量的增大,传统关系型数据库越来越不能满足对于海量数据存储的需求。对于分布式关系型数据库,我们了解其底层存储结构是非常重要的。本文将介......
  • HLS协议介绍及点播实现原理
    HTTPLiveStreaming(缩写是HLS)是一个由苹果公司提出的基于HTTP的流媒体网络传输协议。是苹果公司QuickTimeX和iPhone软件系统的一部分。它的工作原理是把整个流分成一个......
  • Ubuntu安装easyconnect
    使用easyconnect的deb包下载点击图标没反应。通过命令行启动报错/usr/share/sangfor/EasyConnect/EasyConnect[1]    57076 segmentation fault (core dumped) ......
  • C++ 编译依赖管理系统分析以及 srcdep 介绍
    C++编译依赖管理系统分析以及srcdep介绍如果用C++写一个中小型软件,有要用到很多第三方库的话,相信不少人会觉得比较麻烦。很多新兴的语言都有了统一的依赖管理系统和......
  • oracle 多行合并成一行: listagg within group CONNECT BY 可以和递归方法一起使用查
    oracle多行合并成一行:listaggwithingroupCONNECTBY可以和递归方法一起使用查询路径:https://www.bbsmax.com/A/A7zgpjGYJ4/oracle多行合并成一行:listaggwit......