首页 > 其他分享 >correct = pred.eq(labels).sum() 的解读

correct = pred.eq(labels).sum() 的解读

时间:2024-10-21 21:18:04浏览次数:8  
标签:torch tensor pred sum labels 张量 True eq

        correct = pred.eq(labels).sum()怕是深度学习demo中最常见的代码了,eq()和sum()都是python中很常用的函数,但是这里的都是prtorch里面的函数,与python中的还是有一些区别的。

python中的用法 

        python中的eq()的典型用法:

from operator import eq

a = [1, 2, 3, 4, 1]
b = [1, 8, 3, 0, 5]

print(a.__eq__(b))   # False
print(eq(a, b))   # False

        python默认的__eq__() 方法返回一个布尔值(True 或 False),表示两个对象是否相等。 __eq__() 方法通常与 __ne__()(不相等)方法一起使用,以便提供完整的相等和不相等比较。
        operator 的eq()函数,实际上是在调用对象的 __eq__() 方法,也就是说:eq(a, b)与a.__eq__(b)是等价的。

        python中的sum ()函数

        “sum ()” 是 Python 中的一个内置函数,用于对可迭代对象(如列表、元组等)中的元素进行求和操作。例如,sum ([1, 2, 3]) 的结果是 6。它接受一个可迭代对象作为参数,并返回这些元素的总和。如果可迭代对象为空,sum () 通常返回 0。

pytorch中的用法 

        pytorch中的eq()函数:

        在 PyTorch 中,eq() 函数是 PyTorch 中实现张量比较的常用函数之一。eq() 函数用于逐元素地比较两个张量(tensor)的对应元素是否相等。如果两个元素相等,则返回 True,否则返回 False。
        demo:

import torch

a = torch.tensor([1, 2, 3, 4, 1])   # 将列表转换为tensor
b = torch.tensor([1, 8, 3, 0, 5])

print(torch.eq(a, b))      # 逐元素比较,返回一个bool类型的tensor

# tensor([ True, False,  True, False, False])

         以及:

import torch

a = torch.tensor([[1, 2, 3, 4, 1], [2, 3, 5, 7, 8]])   # 将列表转换为tensor
b = torch.tensor([[1, 8, 3, 0, 5], [2, 3, 6, 7, 9]])

print(torch.eq(a, b))      # 逐元素比较,返回一个bool类型的tensor

# tensor([[ True, False,  True, False, False],
#         [ True,  True, False,  True, False]])

      pytorch中的eq()函数详解:

        函数定义   
torch.eq(input, other, *, out=None) -> Tensor
        参数说明

        input (Tensor): 要比较的第一个张量。
        other (Tensor or float): 要比较的第二个张量或标量。
        out (Tensor, optional): 输出张量,如果指定,结果将写入这个张量。

        返回值

        返回一个布尔类型的张量,张量的形状与 input相同,对应位置的元素为 True 或 False。

        用法示例

        示例 1: 比较两个张量

import torch

# 定义两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 4, 3])

# 使用 eq() 函数比较
result = torch.eq(tensor1, tensor2)

print(result)

# tensor([ True, False,  True])


        示例 2: 比较张量和标量

import torch

# 定义一个张量和一个标量
tensor1 = torch.tensor([1, 2, 3])
scalar = 2

# 使用 eq() 函数比较
result = torch.eq(tensor1, scalar)

print(result)
# tensor([False,  True, False])


        示例 3: 使用 out 参数

import torch

# 定义两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 4, 3])

# 定义输出张量
output = torch.empty_like(tensor1, dtype=torch.bool)

# 使用 eq() 函数比较,并将结果写入输出张量
torch.eq(tensor1, tensor2, out=output)

print(output)

# tensor([ True, False,  True])

        注意事项

        形状匹配:如果 input 和 other 是张量,它们的形状必须是可广播的(broadcastable)。这意味着它们的形状要么相同,要么可以通过广播规则扩展为相同的形状。

        数据类型:eq() 函数比较时会考虑数据类型。例如,比较一个 float 类型的张量和一个 int 类型的张量可能会导致不准确的结果,因此建议确保比较的张量具有相同的数据类型。

        性能优化:为了提高性能,可以使用 out 参数预先分配输出张量,这样可以避免在计算过程中频繁地分配和释放内存。

        总结

        torch.eq() 是 PyTorch 中用于逐元素比较两个张量是否相等的函数。它返回一个布尔张量,指示每个位置的元素是否相等。这个函数在处理需要比较和条件判断的张量操作时非常有用。通过合理使用 out 参数,还可以优化性能。

pytorch中的sum()函数:

        在 PyTorch 中,sum() 函数用于计算张量中所有元素的总和。它可以用于一维、二维或多维张量,并且可以选择沿着特定的维度进行求和。sum() 函数非常灵活,适用于多种场景,如计算损失函数、统计等。

      pytorch中的sum()函数详解:

        函数定义
torch.sum(input, dim=None, keepdim=False, dtype=None) -> Tensor
        参数说明

        input (Tensor): 输入张量,即需要进行求和的张量。
        dim (int or tuple of ints, optional): 要进行求和的维度。如果指定,将对指定维度进行求和。如果未指定,将对所有元素进行求和。
        keepdim (bool, optional): 是否保持求和后的维度。默认值为 False,即求和后的维度会被压缩(减少)。
        dtype (torch.dtype, optional): 指定输出张量的数据类型。如果未指定,将使用输入张量的数据类型。

        返回值

        返回一个张量,包含求和结果。如果指定了 dim,返回的张量将减少相应的维度。

        用法示例


        示例 1: 计算整个张量的总和

import torch

# 定义一个张量
tensor = torch.tensor([1, 2, 3, 4, 5])

# 计算整个张量的总和
total_sum = torch.sum(tensor)

print(total_sum)

# tensor(15)


        示例 2: 沿着特定维度求和

import torch

# 定义一个二维张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

print(tensor.shape)   # torch.Size([2, 3])

# 沿着第0维(行)求和
row_sum = torch.sum(tensor, dim=0)

# 沿着第1维(列)求和
col_sum = torch.sum(tensor, dim=1)

print("Row sum:", row_sum)   # Row sum: tensor([5, 7, 9])
print("Column sum:", col_sum)   # Column sum: tensor([ 6, 15])


        示例 3: 保持维度求和

import torch

# 定义一个二维张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 沿着第0维(行)求和,并保持维度
row_sum_keepdim = torch.sum(tensor, dim=0, keepdim=True)

# 沿着第1维(列)求和,并保持维度
col_sum_keepdim = torch.sum(tensor, dim=1, keepdim=True)

print("Row sum (keepdim):", row_sum_keepdim)    # Row sum (keepdim): tensor([[5, 7, 9]])
print("Col sum (keepdim):", col_sum_keepdim)    # Col sum (keepdim): tensor([[ 6],[15]])


        示例 4: 指定输出数据类型
 

import torch

# 定义一个整数类型的张量
tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)

# 计算总和,并指定输出为浮点数类型
total_sum_float = torch.sum(tensor, dtype=torch.float32)

print("Total sum (float):", total_sum_float)   # Total sum (float): tensor(15.)
print("Type:", total_sum_float.dtype)          # Type: torch.float32

        注意事项

        数据类型:默认情况下,sum() 函数会保留输入张量的数据类型。如果需要特定的输出数据类型,可以使用 dtype 参数进行指定。

        维度选择:dim 参数可以是单个整数,也可以是一个整数元组,用于指定多个维度进行求和。

        性能优化:在处理大型张量时,sum() 函数通常是高效的,但根据具体需求,可能需要结合其他操作(如 keepdim)来优化内存使用和计算效率。

        总结

        torch.sum() 是 PyTorch 中用于计算张量元素总和的函数。它可以通过指定 dim 参数沿着特定维度进行求和,并可以通过 keepdim 参数控制输出张量的维度。灵活使用这些参数,可以满足多种求和需求。

correct = pred.eq(labels).sum() 的解读

        回过头来解读 pred.eq(labels).sum():

        pred.eq(labels):

        这部分代码调用了 eq() 函数,用于逐元素比较 pred 张量和 labels 张量。
        pred 是模型的预测结果,通常是一个张量,包含各个样本的预测类别。
        labels 是真实标签,实际上是对应的目标类别。
        eq() 函数会返回一个布尔张量,其中每个位置的值表示 pred 和 labels 的对应元素是否相等。
        例如,如果 pred 是 [1, 0, 2, 1],而 labels 是 [1, 1, 2, 1],那么 pred.eq(labels) 将返回 [True,         False, True, True]。
        .sum():

                在得到的布尔张量上调用 sum() 函数,True 会被视作 1,False 会被视作 0。
        因此,sum() 将返回预测正确的样本数量。
                之前的例子,[True, False, True, True] 的总和就是 3,表示 3 个样本的预测是正确的。
        赋值给 correct:

        最终,correct 变量将包含模型预测正确的样本数量。
        代码示例
        我们来看一个完整的示例,以便更好地理解这个过程:

import torch

# 模拟的预测值和真实标签
pred = torch.tensor([1, 0, 2, 1])   # 预测值
labels = torch.tensor([1, 1, 2, 1])   # 真实标签

# 计算预测正确的样本数量
correct = pred.eq(labels).sum()

print("Number of correct predictions:", correct.item())   # Number of correct predictions: 3


        总结
        pred.eq(labels) 用于逐元素比较预测值和真实标签,返回布尔值张量。
        .sum() 会统计 True 的数量,即预测正确的样本数量。
        这种方式在模型评估中很常见,尤其是在分类问题中,可以用来计算准确率等指标。

标签:torch,tensor,pred,sum,labels,张量,True,eq
From: https://blog.csdn.net/xulibo5828/article/details/143115452

相关文章

  • Towards Explainable Traffic Flow Prediction with Large Language Models
    <s>[INST]<<SYS>>Role:Youareanexperttrafficvolumepredictionmodel,thatcanpredictthefuturevolumevaluesaccordingtospatialtemporalinformation.Wewantyoutoperformthetrafficvolumepredictiontask,consideringthenea......
  • abc376E Max x Sum
    有序列A[N]和B[N],选出一组大小为K的下标,让A[i]的最大值乘以(B[i]之和)的结果最小,求最小值。1<=T<=2E5,1<=K<=N<=2E5,1<=A[i],B[i]<=1E6分析:因为A[i]跟B[i]要同步选,因此对下标排序,然后枚举每个A[i]作为最大值,从B[i]中选出最小的K个求和,得到结果,B[i]之和可以用堆来维护。#inclu......
  • 3DA3 C02 Predictive Data Analytics
    Assignment1,Commerce3DA3C02-PredictiveDataAnalyticsTocompletethisassignment,pleasecreateaJupyternotebook.Thecodeinyourjupyternotebookshouldprovideanswerstoquestionsaskedintheassignment.Pleasesubmittheassignmentbyuploadin......
  • [ABC376E] Max × Sum 题解
    [ABC376E]Max×Sum题解原题链接洛谷链接一道简单的推性质题,首先明确一个性质,子集是非连续的,所以在计算时并不用连续区间求。拿过题来,首先想的是枚举\(B\)的最小子集,但其复杂度为\(O(C_N^K)\)复杂度过高,不足以通过本题。于是转变思路,枚举\(A\)之中的最大值。若\(a_i......
  • 时延求和(Delay-and-Sum, DAS)波束形成器
    目录1.问题描述2.DAS波束形成3.DAS波束响应与波束图1.问题描述假设存在一个声源以及由N个阵元组成的麦克风阵列,且声源到各个阵元的传播信道只会引入时延与衰减,即......
  • C - sum(牛客小白月赛102)
    题目链接:C-sum题目描述:示例说明:解:这题典型的贪心问题,是求最小的操作次数。首先我们可以先算出这n个数的和s,s和sum的大小有三种情况。当s=sum时,一个数字也不用修改,答案为0。而剩下的两种情况可以合为一种情况来做。首先我们要知道如果把这n个数都变为相反数,则s也会变为......
  • OpenCity: Open Spatio-Temporal Foundation Models for Traffic Prediction
    1.数据准备在这个数据处理过程中,以数据集PEMS07M为例,整个数据抽取和划分过程如下:初始数据维度:原始训练数据data_train的维度为(12672,228,3)。其中:12672表示时间步数,代表不同的时间点采样的数据。228表示空间节点数(例如不同的交通站点)。3表示每个节点在每个......
  • 闯关leetcode——112. Path Sum
    大纲题目地址内容解题代码地址题目地址https://github.com/f304646673/leetcode/tree/main/112-Path-Sum内容GiventherootofabinarytreeandanintegertargetSum,returntrueifthetreehasaroot-to-leafpathsuchthataddingupallthevalues......
  • uniapp精仿微信源码,基于SumerUI和Uniapp前端框架的一款仿微信APP应用,界面漂亮颜值高,视
    uniapp精仿微信源码,基于SumerUI和Uniapp前端框架的一款仿微信APP应用,界面漂亮颜值高,视频商城小工具等,朋友圈视频号即时聊天用于视频,商城,直播,聊天,等等场景,源码分享sumer-weixin介绍uniapp精仿微信,基于SumerUI3.0和Uniapp前端框架的一款仿微信APP应用,界面漂亮颜值高,视频......
  • Subsequence and Prefix Sum
    SubsequenceandPrefixSum\(n\)才\(100\),\(a_i\)才\(20\),显然DP。设\(f_{i,j}\)表示第\(i\)个数,前\(i\)个数前缀和为\(j\)的方案数。显然,\(f_{0,0}=1\)。留意到如果\(j=0\),那么加入和不加入第\(i\)个数,最终的答案序列是一样的,因此此时加入第\(i\)个数对答......