首页 > 其他分享 >负对数似然(NLL)和困惑度(PPL)

负对数似然(NLL)和困惑度(PPL)

时间:2024-08-29 20:36:33浏览次数:4  
标签:似然 tensor torch PPL 类别 对数 logits NLL

让我们通过一个简单的例子来演示这段代码的计算过程,包括负对数似然(NLL)和困惑度(PPL)的计算。为了简化,我们将假设一个非常小的模型输出和数据。

假设:

  • 我们有两个样本(即 batch size 为 2)。
  • 每个样本有 3 个可能的类别,S_logits 是模型输出的 logits。
  • smask 是一个掩码,假设全部为 True,即我们对所有样本和所有类别都进行处理。
  • s_batch_id 是一个表示每个样本的索引的向量,用于 scatter_mean 的计算。

1. 模型输出的 logits:

假设 r_pred_S_logits 的最后一层输出如下(为了简单,假设只有一个时间步长):

import torch

# 假设的 logits
r_pred_S_logits = [torch.tensor([[[2.0, 1.0, 0.1], [1.0, 2.5, 0.5]]])]

# 掩码
smask = torch.tensor([True, True])

# 批次 ID(假设第一个样本和第二个样本)
s_batch_id = torch.tensor([0, 1])

2. 计算 softmax 概率分布:

首先,对 S_logits 进行 softmax 操作:

S_logits = r_pred_S_logits[-1][0][smask]  # shape: (2, 3)
S_dists = torch.softmax(S_logits, dim=-1)  # shape: (2, 3)
print(S_dists)

这将输出:

tensor([[0.6590, 0.2424, 0.0986],
        [0.2312, 0.6285, 0.1403]])

每一行是一个样本的概率分布。

3. 采样类别:

然后,从 S_dists 中使用 torch.multinomial 采样类别:

pred_S = torch.zeros_like(smask, dtype=torch.long)
pred_S[smask] = torch.multinomial(S_dists, num_samples=1).squeeze()
print(pred_S)

假设采样结果为:

tensor([0, 1])

这意味着第一个样本预测为类别 0,第二个样本预测为类别 1。

4. 计算 NLL:

我们从 S_dists 中提取出预测类别的概率,并计算负对数似然(NLL):

S_probs = S_dists[torch.arange(s_batch_id.shape[0]), pred_S[smask]]
print(S_probs)

假设输出为:

tensor([0.6590, 0.6285])

计算 NLL:

nlls = -torch.log(S_probs)
print(nlls)

输出:

tensor([0.4170, 0.4642])

这两个值是每个样本的 NLL 值。

5. 计算 PPL:

最后,通过 scatter_mean 计算每个 batch 的平均 NLL(在这个简单的例子中,每个样本有一个唯一的 ID,所以直接取平均值),然后困惑度(PPL)可以通过取指数得到:

from torch_scatter import scatter_mean

ppl = scatter_mean(nlls, s_batch_id)
print(ppl)

结果:

tensor([0.4170, 0.4642])

困惑度可以通过取指数计算得到:

ppl_value = torch.exp(ppl)
print(ppl_value)

结果:

tensor([1.5172, 1.5901])

总结

在这个简单的例子中:

  • NLL: 是 0.41700.4642,分别对应两个样本的负对数似然。
  • 困惑度(PPL): 通过 exp(NLL) 计算得出 1.51721.5901

这些值反映了模型预测的质量,困惑度越小,表示模型对样本的预测越好。

负对数似然是怎么计算的

负对数似然(Negative Log-Likelihood, NLL)是用来衡量模型预测的概率与真实类别之间的差异的一种损失函数。在分类任务中,它可以用于评估模型预测的准确性。NLL 的计算是基于模型输出的概率(通常通过 softmax 函数生成的概率分布)和真实类别标签的。

负对数似然的计算步骤

  1. 模型输出 logits:

    • 模型通常输出 logits,也就是未归一化的分数。这些分数代表了模型对每个类别的信心。
  2. Softmax 函数:

    • 将 logits 转换为概率分布。这一步骤使用 softmax 函数,使得每个类别的分数被归一化为一个概率值,所有类别的概率之和为 1。

    公式如下:

    [
    P(y_i) = \frac{\exp(\text{logit}i)}{\sum^{N} \exp(\text{logit}_j)}
    ]

    其中,( P(y_i) ) 是类别 ( y_i ) 的概率,logit 是模型输出的原始分数。

  3. 选择预测类别的概率:

    • 从 softmax 生成的概率分布中,选择实际发生的类别(或模型预测的类别)的概率。
  4. 计算负对数似然:

    • 取所选概率的负对数作为负对数似然值(NLL)。

    公式如下:

    [
    \text{NLL} = -\log(P(y_{\text{true}}))
    ]

    其中,( P(y_{\text{true}}) ) 是真实类别的预测概率。

例子

假设我们有一个三分类问题,模型输出的 logits 为:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])

1. Softmax 计算概率分布:

probs = torch.softmax(logits, dim=-1)
print(probs)

这将输出:

tensor([0.6590, 0.2424, 0.0986])

即,类别 0 的概率是 0.6590,类别 1 的概率是 0.2424,类别 2 的概率是 0.0986。

2. 假设真实类别是 0,那么选择类别 0 的概率:

P_true = probs[0]
print(P_true)

输出:

tensor(0.6590)

3. 计算负对数似然:

nll = -torch.log(P_true)
print(nll)

输出:

tensor(0.4170)

这个值 ( 0.4170 ) 就是类别 0 的负对数似然,它反映了模型对这个类别的预测质量。

总结

  • 负对数似然(NLL) 是模型对某个类别预测概率的负对数。
  • NLL 越小,说明模型对真实类别的预测概率越高,模型的表现越好。
  • NLL 越大,说明模型对真实类别的预测概率越低,模型的表现越差。

标签:似然,tensor,torch,PPL,类别,对数,logits,NLL
From: https://www.cnblogs.com/csjywu01/p/18387519

相关文章

  • SAP: 如何运行Web Dynpro Application
     照着《SAPWebDynproForABAP开发技术详解 高端应用》学习抄例子,没有看过初级应用篇直接来学习高端应用知道会有很多知识断了。抄完例子后,不知道如何运行下面补充一下运行操作步骤: 一、创建一个WebDynproApplication,浏览器访问WebDynpro应用。 选择WebDynproC......
  • [Azure Application Insights]Azure应用程序见解概述页面中workspace的link不见了?
    问题描述在AzureApplicationInsights的概述页面中,可以直接点击WorkspaceLink进入到Workspace资源页面。但是,在下面的示例图中,WorkspaceLink不见了?这是什么原因呢? 问题解答这是因为Workspace的资源组发生了改变。ApplicationInsights无法根据WorksapceResour......
  • Spring Boot 框架中配置文件 application.properties 当中的所有配置大全
    SpringBoot框架中配置文件application.properties当中的所有配置大全#SPRINGCONFIG(ConfigFileApplicationListener)spring.config.name=#配置文件名(默认为'application')spring.config.location=#配置文件的位置#多环境配置文件激活属性spring.profiles.active......
  • Android开发 - Application 基础类全局的应用级状态管理解析
    Application是什么Application是一个基础类,用于全局的应用级状态管理。它在应用程序启动时被创建,并在应用程序关闭时销毁。Application对象的生命周期与应用程序的生命周期一致,因此它非常适合用来保存全局的应用状态信息或初始化全局资源Application的主要作用全局状态管......
  • A review of ssm and their applications in connectedand automated vehicles safety
    ABSTRACTSurrogateSafetyMeasures(SSM)areimportantforsafetyperformanceevaluation,since crashesarerareeventsandhistoricalcrashdatadoesnotcapturenearcrashesthatarealsocriticalforimprovingsafety.Thispaper focusesonSSMandthei......
  • 苹果 exchange apples
    萧伯纳的那句:你有一个苹果,我有一个苹果,我们彼此交换.全句的英文原文是什么?扫码下载作业帮搜索答疑一搜即得答案解析 查看更多优质解析 解答一举报IfyouhaveanappleandIhaveanapple,andweexchangeapples,webothstillonlyhaveoneappl......
  • 修改SpringBoot的配置文件application.yaml后启动失败
    经常碰到修改application.yaml文件之后,SpringBoot项目启动失败的,报错信息如下ConnectedtothetargetVM,address:'127.0.0.1:7105',transport:'socket'21:12:59.122[main]DEBUGorg.springframework.boot.context.logging.ClasspathLoggingApplicationListener-App......
  • Python3.11二进制AI项目程序打包为苹果Mac App(DMG)-应用程序pyinstaller制作流程(App
    众所周知,苹果MacOs系统虽然贵为Unix内核系统,但由于系统不支持N卡,所以如果想在本地跑AI项目,还需要对相关的AI模块进行定制化操作,本次我们演示一下如何将基于Python3.11的AI项目程序打包为MacOS可以直接运行的DMG安装包,可以苹果系统中一键运行AI项目。MacOs本地部署AI项目首先确......
  • 017、二级Java知识点之Java Applet与图像处理:从基础到应用
    JavaApplet与图像处理:从基础到应用1.题目解析先来详细解析题目中给出的代码示例:importjava.____.*;importjava.awt.*;importjava.net.*;publicclassImageDemoextendsApplet{privateImageimage;publicvoid______(){Stringimage......
  • Android Kotlin优化代码整洁:with、applay以及run是什么,作用,区别
    目录为什么需要使用with、applay以及runwith、applay以及run是什么、作用以及三者区别一、为什么需要使用with、applay以及run我们在开发项目的过程当中,不可避免,一个界面的内容会有很多,如下initivew方法,会有Recyclerview的初始化,长按以及触摸事件设置,以及生命周期的注册,主......