首页 > 其他分享 >pytorch gather函数

pytorch gather函数

时间:2024-02-04 21:34:30浏览次数:28  
标签:index torch 15 函数 gather ids pytorch print tensor

转载于:https://www.zhihu.com/question/562282138/answer/2947708508?utm_id=0
官方文档链接:
https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather

torch.gather()的定义非常简洁:
在指定dim上,从原tensor中获取指定index的数据, 看到这个核心定义,我们很容易想到gather()的基本想法就是从完整数据中按索引取值,比如下面从列表中按索引取值:

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子。
对于深度学习常用的批量tensor数据,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor
中取出指定乱序索引下的数据,因此其用途如下:
方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的。

实验

ex0 输入行向量index,并替换行索引(dim=0):

import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]

tensor_1 = tensor_0.gather(0, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor1")
print(tensor_1)

#输出如下:
====>> tensor0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
====>> tensor1
tensor([[9, 7, 5]])

过程:

ex1 输入行向量index,并替换列索引(dim=1)

import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_2 = tensor_0.gather(1, index)

print("====>> tensor0")
print(tensor_0)
print("====>> tensor2")
print(tensor_2)

输出:
====>> tensor0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
====>> tensor2
tensor([[5, 4, 3]])

ex2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor

([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

tensor([[5],
        [7],
        [9]])

ex3 输入二维矩阵index,并替换列索引(dim=1)

index = torch.tensor([[0, 2], 
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

tensor([[3, 5],
        [7, 8]])

![](/i/l/?n=24&i=blog/1047308/202402/1047308-20240204211453119-1392642658.png)

##要点
###归纳出torch.gather()的使用要点
###输出value的shape等于输入index的shape
###索引input时,其索引构成过程:对输入index中的每个value的索引,只在对应的dim上将该索引的索引值替换为输入index中的对应value,就构成了对input的索引
###用得到的input的索引,对input进行索引得到输出value

##其他应用示例, 在mae的代码中,
https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L123
![](/i/l/?n=24&i=blog/1047308/202402/1047308-20240204211808058-908922774.png)

如上代码两次argsort代码示例:

import torch

noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

print(noise)
print(ids_shuffle)
print(ids_restore)

输出如下:

tensor([[0.8787, 0.3496, 0.4642, 0.1852, 0.2965],
[0.0701, 0.1533, 0.1716, 0.1579, 0.5323],
[0.0827, 0.5038, 0.4169, 0.1121, 0.9830]])
tensor([[3, 4, 1, 2, 0],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])
tensor([[4, 2, 3, 0, 1],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])


##gather mae中的用法

import torch

D = 8
x = torch.randint(0, 20, (3, 5, D))

noise = torch.randint(0, 20, (3, 5))
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

len_keep = 2
ids_keep = ids_shuffle[:, :len_keep]

index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
x_masked = torch.gather(x, dim=1, index=index)

print("====>>> x")
print(x)

print("====>>> noise")
print(noise)

print("====>>> ids_shuffle")
print(ids_shuffle)

print("====>>> ids_keep.unsqueeze(-1)")
print(ids_keep.unsqueeze(-1))

print("====>>> ids_keep")
print(ids_keep)

print("====>>> index")
print(index)

print("====>>> x_masked")
print(x_masked)



输出如下:

====>>> x
tensor([[[13,  6,  7, 15,  1,  9,  7, 17],
         [15, 15, 11, 15, 17,  4,  6, 15],
         [10, 18,  5,  6, 18, 10, 19,  2],
         [11, 19, 19, 11, 10, 11,  7, 11],
         [18, 15, 17,  5,  7,  5,  9,  5]],

        [[ 4, 12,  5,  7, 12, 15, 14,  6],
         [15, 12, 13, 14,  8,  5, 15, 11],
         [12, 17, 12, 11,  2,  9,  8,  1],
         [18,  9,  6, 12, 19, 17, 10,  3],
         [11,  4,  9, 18,  1, 17,  0, 10]],

        [[18,  5, 11, 18, 19,  6,  0, 19],
         [19, 15, 12,  9, 18,  3, 18,  1],
         [15,  3, 17, 15,  3, 16,  0,  6],
         [ 1,  4, 12, 10,  4, 10, 10,  4],
         [18, 13,  3, 16,  1,  2, 15, 17]]])
====>>> noise
tensor([[ 8, 16, 16,  4, 17],
        [ 0, 13,  4, 19, 17],
        [14, 17,  1,  9,  4]])
====>>> ids_shuffle
tensor([[3, 0, 1, 2, 4],
        [0, 2, 1, 4, 3],
        [2, 4, 3, 0, 1]])
====>>> ids_keep.unsqueeze(-1)
tensor([[[3],
         [0]],

        [[0],
         [2]],

        [[2],
         [4]]])
====>>> ids_keep
tensor([[3, 0],
        [0, 2],
        [2, 4]])
====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 2, 2, 2]],

        [[2, 2, 2, 2, 2, 2, 2, 2],
         [4, 4, 4, 4, 4, 4, 4, 4]]])
====>>> x_masked
tensor([[[11, 19, 19, 11, 10, 11,  7, 11],
         [13,  6,  7, 15,  1,  9,  7, 17]],

        [[ 4, 12,  5,  7, 12, 15, 14,  6],
         [12, 17, 12, 11,  2,  9,  8,  1]],

        [[15,  3, 17, 15,  3, 16,  0,  6],
         [18, 13,  3, 16,  1,  2, 15, 17]]])

标签:index,torch,15,函数,gather,ids,pytorch,print,tensor
From: https://www.cnblogs.com/yanghailin/p/18007025

相关文章

  • 【pwn】pwnable_start --只有read和write函数的getshell
    首先查一下程序的保护情况保护全关!!!然后看ida逻辑ida的结果很简洁,只有一段汇编代码,我们再来看看nc情况现在我们来分析一下汇编代码 mov  ecx,esp            ;addr.text:08048089B214            mov  dl......
  • PyTorch下,使用list放置模块,导致计算设备不一的报错
    报错在复现Transformer代码的训练阶段时,发生报错:RuntimeError:Expectedalltensorstobeonthesamedevice,butfoundatleasttwodevices,cuda:0andcpu!解决方案通过next(linear.parameters()).device确定model已经在cuda:0上了,同时输入model.forward()的......
  • PyTorch 2.2 中文官方教程(十六)
    介绍torch.compile原文:pytorch.org/tutorials/intermediate/torch_compile_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0注意点击这里下载完整的示例代码作者:WilliamWentorch.compile是加速PyTorch代码的最新方法!torch.compile通过将PyTorch代码JIT编译成优化的......
  • PyTorch 2.2 中文官方教程(十八)
    开始使用完全分片数据并行(FSDP)原文:pytorch.org/tutorials/intermediate/FSDP_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:HamidShojanazeri,YanliZhao,ShenLi注意在github上查看并编辑本教程。在大规模训练AI模型是一项具有挑战性的任务,需要大量的计算能力和资源......
  • PyTorch 2.2 中文官方教程(十九)
    使用RPC进行分布式管道并行原文:pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:ShenLi注意在github中查看并编辑本教程。先决条件:PyTorch分布式概述单机模型并行最佳实践开始使用分布式RPC框......
  • PyTorch 2.2 中文官方教程(二十)
    移动设备在iOS上进行图像分割DeepLabV3原文:pytorch.org/tutorials/beginner/deeplabv3_on_ios.html译者:飞龙协议:CCBY-NC-SA4.0作者:JeffTang审阅者:JeremiahChung介绍语义图像分割是一种计算机视觉任务,使用语义标签标记输入图像的特定区域。PyTorch语义图像分割De......
  • PyTorch 2.2 中文官方教程(十一)
    使用PyTorchC++前端原文:pytorch.org/tutorials/advanced/cpp_frontend.html译者:飞龙协议:CCBY-NC-SA4.0PyTorchC++前端是PyTorch机器学习框架的纯C++接口。虽然PyTorch的主要接口自然是Python,但这个PythonAPI坐落在一个庞大的C++代码库之上,提供了基础数据......
  • PyTorch 2.2 中文官方教程(十二)
    自定义C++和CUDA扩展原文:pytorch.org/tutorials/advanced/cpp_extension.html译者:飞龙协议:CCBY-NC-SA4.0作者:PeterGoldsboroughPyTorch提供了大量与神经网络、任意张量代数、数据处理和其他目的相关的操作。然而,您可能仍然需要更定制化的操作。例如,您可能想使用在论......
  • PyTorch 2.2 中文官方教程(十三)
    在C++中注册一个分发的运算符原文:pytorch.org/tutorials/advanced/dispatcher.html译者:飞龙协议:CCBY-NC-SA4.0分发器是PyTorch的一个内部组件,负责确定在调用诸如torch::add这样的函数时实际运行哪些代码。这可能并不简单,因为PyTorch操作需要处理许多“层叠”在彼此之......
  • PyTorch 2.2 中文官方教程(十四)
    参数化教程原文:译者:飞龙协议:CCBY-NC-SA4.0作者:MarioLezcano注意点击这里下载完整示例代码在本教程中,您将学习如何实现并使用此模式来对模型进行约束。这样做就像编写自己的nn.Module一样容易。对深度学习模型进行正则化是一项令人惊讶的挑战。传统技术,如惩罚方法,通......