首页 > 其他分享 >pytorch张量索引

pytorch张量索引

时间:2022-11-10 10:36:24浏览次数:78  
标签:dim tensor torch print 张量 索引 pytorch 0.0000 Matrix


一、pytorch返回最值索引

1 官方文档资料

1.1 torch.argmax()介绍

 返回最大值的索引下标

函数:
torch.argmax(input, dim, keepdim=False) → LongTensor

返回值:
Returns the indices of the maximum values of a tensor across a dimension.

参数:
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.
keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.

1.2 torch.argmin()介绍 

 返回最小值的索引下标

函数:
torch.argmin(input, dim, keepdim=False) → LongTensor

返回值:
Returns the indices of the mimimum values of a tensor across a dimension.

参数:
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.
keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.

2 代码示例

2.1 torch.argmax()代码示例

>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.3772, -0.1143],
[ 0.2217, -0.1897]],

[[ 0.1488, -0.8758],
[ 1.7734, -0.5929]]])
>>> print(Matrix.argmax(dim=0))
tensor([[0, 0],
[1, 0]])
>>> print(Matrix.argmax(dim=1))
tensor([[0, 0],
[1, 1]])
>>> print(Matrix.argmax(dim=2))
tensor([[0, 0],
[0, 0]])
>>> print(Matrix.argmax())
tensor(6)

2.2 torch.argmin()代码示

>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.5821, 0.2889],
[ 0.4669, -0.3135]],

[[-0.4567, 0.2975],
[-1.5084, 0.7320]]])
>>> print(Matrix.argmin(dim=0))
tensor([[1, 0],
[1, 0]])
>>> print(Matrix.argmin(dim=1))
tensor([[1, 1],
[1, 0]])
>>> print(Matrix.argmin(dim=2))
tensor([[1, 1],
[0, 0]])
>>> print(Matrix.argmin())
tensor(6)

 二、pytorch返回任意值索引

tens = tensor([[  101,   146,  1176, 21806,  1116,  1105, 18621,   119,   102,     0,
0, 0, 0],
[ 101, 1192, 1132, 1136, 1184, 146, 1354, 1128, 1127, 117,
1463, 119, 102],
[ 101, 6816, 1905, 1132, 14918, 119, 102, 0, 0, 0,
0, 0, 0]])
idxs = torch.tensor([(i == 101).nonzero() for i in tens])

from torch import tensor

tens = torch.tensor([[ 101, 146, 1176, 21806, 1116, 1105, 18621, 119, 102, 0,
...: 0, 0, 0],
...: [ 101, 1192, 1132, 1136, 1184, 146, 1354, 1128, 1127, 117,
...: 1463, 119, 102],
...: [ 101, 6816, 1905, 1132, 14918, 119, 102, 0, 0, 0,
...: 0, 0, 0]])

(tens == 101).nonzero()[:, 1]
tensor([0, 0, 0])

三、pytorch 只保留tensor的最大值或最小值,其他位置置零

如下,x是输入张量,dim指定维度,max可以替换成min 

import torch

if __name__ == '__main__':

x = torch.randn([1, 3, 4, 4]).cuda()

mask = (x == x.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
result = torch.mul(mask, x)

print(x)
print(mask)
print(result)

输出效果:

tensor([[[[-0.8807,  0.1029,  0.0184,  1.2695],
[-0.0934, 1.0650, -0.2927, 0.0049],
[ 0.2338, -1.8663, 1.2763, 0.7248],
[-1.5138, 0.6834, 0.1463, 0.0650]],

[[ 0.5020, 1.6078, -0.0104, 1.2042],
[ 1.8859, -0.4682, -0.1177, 0.5197],
[ 1.7649, 0.4585, 0.6002, 0.3350],
[-1.1384, -0.0325, 0.8490, 0.6080]],

[[-0.5618, 0.5388, -0.0572, -0.7240],
[-0.3458, 1.3494, -0.0603, -1.1562],
[-0.3652, 1.1885, 1.6293, 0.4134],
[ 1.3009, 1.2027, -0.8711, 1.3321]]]], device='cuda:0')
tensor([[[[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, 0, 0]],

[[1, 1, 0, 0],
[1, 0, 0, 1],
[1, 0, 0, 0],
[0, 0, 1, 0]],

[[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[1, 1, 0, 1]]]], device='cuda:0', dtype=torch.int32)
tensor([[[[-0.0000, 0.0000, 0.0184, 1.2695],
[-0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, 0.7248],
[-0.0000, 0.0000, 0.0000, 0.0000]],

[[ 0.5020, 1.6078, -0.0000, 0.0000],
[ 1.8859, -0.0000, -0.0000, 0.5197],
[ 1.7649, 0.0000, 0.0000, 0.0000],
[-0.0000, -0.0000, 0.8490, 0.0000]],

[[-0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 1.3494, -0.0603, -0.0000],
[-0.0000, 1.1885, 1.6293, 0.0000],
[ 1.3009, 1.2027, -0.0000, 1.3321]]]], device='cuda:0')

Process finished with exit code 0

 四、使用pytorch获取tensor每行中的top k

    ???老铁,这么简单的问题还有问,自己解决去!!!!

参考:

​python - How Pytorch Tensor get the index of specific value - Stack Overflow


https://stackoverflow.com/questions/47863001/how-pytorch-tensor-get-the-index-of-specific-value​

​How Pytorch Tensor get the index of elements?


https://stackoverflow.com/questions/57933781/how-pytorch-tensor-get-the-index-of-elements​

​https://discuss.pytorch.org/t/keep-the-max-value-of-the-array-and-0-the-others/14480/8


https://discuss.pytorch.org/t/keep-the-max-value-of-the-array-and-0-the-others/14480/8​​ ​​javascript:v


标签:dim,tensor,torch,print,张量,索引,pytorch,0.0000,Matrix
From: https://blog.51cto.com/u_13206712/5839954

相关文章

  • pytorch tensor 张量常用方法介绍
    1. view()函数PyTorch 中的view()函数相当于numpy中的resize()函数,都是用来重构(或者调整)张量维度的,用法稍有不同。>>>importtorch>>>re=torch.tensor([1,......
  • pytorch TensorDataset和DataLoader区别
    TensorDatasetTensorDataset可以用来对tensor进行打包,就好像python中的zip功能。该类通过每一个tensor的第一个维度进行索引。因此,该类中的tensor第一维度必须......
  • 【pyfaidx】纯Python实现的FASTA随机索引库
    前言基因组序列的提取,有不少强大的工具像samtools,bedtools,之前也提到pybedtools提取序列。不过pybedtools是对bedtools提供一个Python接口,除了安装pybedtools外,还需......
  • pytorch入门
    初衷:看不懂论文开源代码参考:B站小土堆(土堆yyds~)   PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili 1.环境配置参考:(39条消息)win10......
  • B+树索引适⽤的条件
    假设有张存储人基本信息的表,DDL如下:CREATETABLEperson_info(    idINTNOTNULL auto_increment,    nameVARCHAR(100)NOTNULL,    birthdayDATENOT......
  • ElasticSearch Java API之索引操作
    背景:​​ElasticSearchJava客户端连接ElasticSearch​​以这篇博客为基础​​ElasticSearch:简单介绍以及使用Docker部署ElasticSearch和Kibana​​这篇博客简单部署了E......
  • 【MySQL】深入理解MySQL索引原理(MySQL专栏启动)
    本文导读本篇文章博主对索引做了一个较为初步地概述,主要有2种主要的索引的数据结构b+tree和hash的数据结构,b+树的覆盖索引和回表进行分析,并对b+树存放记录、如何优化B+树索......
  • 【MySQL】深入理解MySQL索引优化器工作原理
    本文导读本文将解读MySQL数据库查询优化器(CBO)的工作原理。简单介绍了MySQLServer的组成,MySQL优化器选择索引额原理以及SQL成本分析,最后通过select查询总结整个查询过程。......
  • 一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】
    ????声明:作为全网AI领域干货最多的博主之一,❤️不负光阴不负卿❤️????深度学习:#超分重建、一文读懂????超分重建经典网络SRGAN详尽教程????最近更新:2022年2月28......
  • 使用PyTorch实现简单的AlphaZero的算法(1):背景和介绍
    在本文中,我们将在PyTorch中为ChainReaction[2]游戏从头开始实现DeepMind的AlphaZero[1]。为了使AlphaZero的学习过程更有效,我们还将使用一个相对较新的改进,称为“Playout......