首页 > 其他分享 >torch topk 使用

torch topk 使用

时间:2024-07-15 11:43:49浏览次数:16  
标签:dim pred torch topk 使用 indices True

torch topk 使用

这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

  • input:一个tensor数据
  • k:指明是得到前k个数据以及其index
  • dim: 指定在哪个维度上排序, 默认是最后一个维度
  • largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
  • sorted:返回的结果按照顺序返回
  • out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。
假设一个tensor F, F∈R^N×D,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

1. python 版本

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364,  0.7912, -0.3263],
        [-0.8013, -0.9083,  0.7973,  0.1458, -0.9156],
        [-0.2334, -0.0142, -0.5493,  0.0673,  0.8185],
        [-0.4075, -0.1097,  0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],   #【0,0】代表 第一个样本最可能属于第一类别
        [2],   # 【1, 0】代表第二个样本最可能属于第二类别
        [4],
        [2]])
# indices_max等于indices
tensor([[True],
        [True],
        [True],
        [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True)  # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538,  1.8789,  0.4451, -0.2526],
        [-0.0413,  0.6366,  1.1155,  0.3484,  0.0395],
        [ 0.0365,  0.5158,  1.1067, -0.9276, -0.2124],
        [ 0.6232,  0.9912, -0.8562,  0.0148,  1.6413]])
# indices
tensor([[2, 3],
        [2, 1],
        [2, 1],
        [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。
其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

2. C++ 版

// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
inline ::std::tupleat::Tensor,at::Tensor Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const {
return at::_ops::topk::call(const_cast<Tensor&>(*this), k, dim, largest, sorted);
}




标签:dim,pred,torch,topk,使用,indices,True
From: https://www.cnblogs.com/michaelcjl/p/18302843

相关文章

  • 使用Spring Boot实现事务管理
    使用SpringBoot实现事务管理大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!SpringBoot中的事务管理在现代的企业应用程序中,事务管理是确保数据完整性和一致性的关键部分。SpringBoot框架通过其强大的事务管理机制,为开发人员提供了简单而高效的方式来......
  • 使用Java实现定时任务调度
    使用Java实现定时任务调度大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!定时任务调度的概述在软件开发中,定时任务调度是一项常见的需求,它允许开发人员周期性地执行特定的任务或操作。Java提供了多种方式来实现定时任务调度,其中包括传统的Timer类、Quar......
  • ThinkPHP6事件系统使用指南
    本文由ChatMoney团队出品在ThinkPHP6中,事件系统提供了一种优雅的方式来实现解耦和动态响应。你可以通过注册事件和对应的监听者来处理各种应用逻辑。事件注册闭包注册闭包是最简单的事件监听者,可以直接在注册时定义。Event::listen("ClosureEvent",function(){var_dump("C......
  • Typora使用Gitee和PicGo搭建免费图床
    Typora使用Gitee和PicGo搭建免费图床一、环境准备1、安装最新版Typora地址:Typora官方中文站2、安装Node.js地址:Node.js1、下载长期服务版LTS下载后,一直默认下一步安装即可2、验证是否安装成功命令行输入以下命令,出现版本号,说明安装成功!node-v二、配置Gi......
  • springboot使用logback日志出现LOG_PATH_IS_UNDEFINED文件夹的问题
    logback现在基本上已经成为springboot日志框架中使用最多的日志实现,在使用中与各中间件集成的一些注意事项记录如下 一SpringBoot中logback读取application.properties(application.yml)中的属性其中使用的时候发现了一个问题,就是如果使用的lobback配置文件的名称是logb......
  • 使用分布式锁解决淘客返利系统中的并发问题
    使用分布式锁解决淘客返利系统中的并发问题大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!在大型淘客返利系统中,高并发是一个常见的挑战。为了保证数据的一致性和系统的稳定性,我们需要有效地管理并发访问,特别是在涉及关键资源或业务操作时。本文将......
  • ElementUI 本身没有提供年份范围选择组件,但可以通过封装两个年份选择器来实现类似的功
    ElementUI本身没有提供年份范围选择组件,但可以通过封装两个年份选择器来实现类似的功能。以下是一个使用Vue2和ElementUI实现年份范围选择器的示例代码: <script>exportdefault{name:'YearRangePicker',//接收父组件传入的年份范围数据props:{value:{......
  • Windows Server 2022 中SQL查询报错:error setting locale info for codepage 65001(取
    解决问题:刚开始我以为是SQLServer升级过程中遇到错误,后面仔细检查错误日志,发现我忽略了一个重要的错误信息“Thecodepage65001isnotsupportedbytheserver.”,codepage65001对应的编码为UTF-8,而数据库排序规则为Chinese_PRC_CI_AS,对应的codepage为936。原来这台SQLSe......
  • ThinkPHP6事件系统使用指南
    本文由ChatMoney团队出品在ThinkPHP6中,事件系统提供了一种优雅的方式来实现解耦和动态响应。你可以通过注册事件和对应的监听者来处理各种应用逻辑。事件注册闭包注册闭包是最简单的事件监听者,可以直接在注册时定义。Event::listen("ClosureEvent",function(){var_dump("C......