Reposted from: https://www.zhihu.com/question/562282138/answer/2947708508?utm_id=0
Official documentation:
https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
The definition of torch.gather() is very concise:
Along the specified dim, fetch from the source tensor the data at the specified index. Given this core definition, the basic idea of gather() is simply to pick values out of complete data by index, like indexing into a list:
lst = [1, 2, 3, 4, 5]
value = lst[2] # value = 3
value = lst[2:4] # value = [3, 4]
The example above takes either a single value or a slice with a contiguous, logical order.
For the batched tensors common in deep learning, however, we often need to pick out multiple values in arbitrary order. gather() is a good tool for this: it extracts data from a batched tensor at specified, possibly out-of-order indices. Its purpose, in short:
Conveniently fetch data from a batched tensor at specified indices, where the indices are fully customizable and may be out of order.
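For intuition, here is a minimal sketch (with made-up shapes and hypothetical names scores/cols) of the kind of selection gather() makes easy: picking a different, arbitrary column from every row of a batched tensor, which would otherwise need an explicit loop.
import torch
scores = torch.randn(4, 10)                      # hypothetical batch: 4 rows, 10 columns
cols = torch.tensor([[7], [0], [3], [9]])        # one arbitrary column per row, shape [4, 1]
picked = scores.gather(1, cols)                  # shape [4, 1]
looped = torch.stack([scores[i, cols[i]] for i in range(4)])  # same result with a Python loop
assert torch.equal(picked, looped)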
## Experiments
ex0: pass a row-vector index and replace the row indices (dim=0):
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_1 = tensor_0.gather(0, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor1")
print(tensor_1)
# Output:
====>> tensor0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
====>> tensor1
tensor([[9, 7, 5]])
Process: for dim=0, the output is built as tensor_1[i][j] = tensor_0[index[i][j]][j], so the single output row is [tensor_0[2][0], tensor_0[1][1], tensor_0[0][2]] = [9, 7, 5].
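A quick sanity check of this rule, reusing tensor_0, index, and tensor_1 from ex0 above (a verification sketch, not part of the original post):
# dim=0 rule: tensor_1[i][j] == tensor_0[index[i][j]][j]
for j in range(3):
    assert tensor_1[0, j] == tensor_0[index[0, j], j]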
ex1: pass a row-vector index and replace the column indices (dim=1)
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_2 = tensor_0.gather(1, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor2")
print(tensor_2)
Output:
====>> tensor0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
====>> tensor2
tensor([[5, 4, 3]])
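The dim=1 case follows the analogous rule tensor_2[i][j] = tensor_0[i][index[i][j]]; a small check reusing the variables from ex1:
# dim=1 rule: tensor_2[i][j] == tensor_0[i][index[i][j]]
for j in range(3):
    assert tensor_2[0, j] == tensor_0[0, index[0, j]]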
ex2: pass a column-vector index (the row vector transposed) and replace the column indices (dim=1)
index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5],
[7],
[9]])
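This column-vector pattern (one index per row) is common in practice, for example reading off each sample's score at its label. A sketch with hypothetical names (logits, targets):
import torch
logits = torch.randn(4, 10)                              # hypothetical scores, [batch, classes]
targets = torch.tensor([3, 7, 0, 9])                     # hypothetical label per sample
target_scores = logits.gather(1, targets.unsqueeze(1))   # [4, 1], one score per sample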
ex3: pass a 2D matrix index and replace the column indices (dim=1)
index = torch.tensor([[0, 2],
[1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[3, 5],
[7, 8]])
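For a 2D index like this, gather along dim=1 matches advanced indexing with a broadcast row index; a small check assuming tensor_0 and index from ex3:
rows = torch.arange(index.size(0)).unsqueeze(1)   # [[0], [1]]
assert torch.equal(tensor_0.gather(1, index), tensor_0[rows, index])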
## Key points
Summarizing how to use torch.gather():
- The shape of the output values equals the shape of the input index.
- Building the indices into input: for each value in index, take that value's own position and replace only the coordinate on the given dim with the value stored in index; this yields a full index into input.
- Indexing input with these indices gives the output values.
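These points hold in any number of dimensions. Below is a brute-force check on a random 3D case (a verification sketch, not from the original post):
import torch
x = torch.randn(2, 3, 4)
idx = torch.randint(0, 3, (2, 3, 4))   # values index dim=1, so they must be < x.size(1)
out = x.gather(1, idx)                 # output shape equals idx's shape
for i in range(2):
    for j in range(3):
        for k in range(4):
            # only the dim=1 coordinate is replaced by the value stored in idx
            assert out[i, j, k] == x[i, idx[i, j, k], k]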
## Other application examples
In the MAE code:
https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L123
An example of the double argsort used in the masking code linked above:
import torch
noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
print(noise)
print(ids_shuffle)
print(ids_restore)
Output:
tensor([[0.8787, 0.3496, 0.4642, 0.1852, 0.2965],
[0.0701, 0.1533, 0.1716, 0.1579, 0.5323],
[0.0827, 0.5038, 0.4169, 0.1121, 0.9830]])
tensor([[3, 4, 1, 2, 0],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])
tensor([[4, 2, 3, 0, 1],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])
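Why argsort twice? torch.argsort(noise, dim=1) gives, for each row, the positions that would sort that row (a random permutation when noise is random); applying argsort again yields the inverse permutation, so gathering with ids_restore undoes the shuffle. A self-contained check (a sketch, not from the original post):
import torch
noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)           # per-row permutation that sorts the row
ids_restore = torch.argsort(ids_shuffle, dim=1)     # inverse of that permutation
shuffled = torch.gather(noise, 1, ids_shuffle)      # rows sorted ascending
restored = torch.gather(shuffled, 1, ids_restore)   # back in the original order
assert torch.equal(restored, noise)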
## How gather is used in MAE
import torch
D = 8
x = torch.randint(0, 20, (3, 5, D))
noise = torch.randint(0, 20, (3, 5))
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
len_keep = 2
ids_keep = ids_shuffle[:, :len_keep]
index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
x_masked = torch.gather(x, dim=1, index=index)
print("====>>> x")
print(x)
print("====>>> noise")
print(noise)
print("====>>> ids_shuffle")
print(ids_shuffle)
print("====>>> ids_keep.unsqueeze(-1)")
print(ids_keep.unsqueeze(-1))
print("====>>> ids_keep")
print(ids_keep)
print("====>>> index")
print(index)
print("====>>> x_masked")
print(x_masked)
Output:
====>>> x
tensor([[[13, 6, 7, 15, 1, 9, 7, 17],
[15, 15, 11, 15, 17, 4, 6, 15],
[10, 18, 5, 6, 18, 10, 19, 2],
[11, 19, 19, 11, 10, 11, 7, 11],
[18, 15, 17, 5, 7, 5, 9, 5]],
[[ 4, 12, 5, 7, 12, 15, 14, 6],
[15, 12, 13, 14, 8, 5, 15, 11],
[12, 17, 12, 11, 2, 9, 8, 1],
[18, 9, 6, 12, 19, 17, 10, 3],
[11, 4, 9, 18, 1, 17, 0, 10]],
[[18, 5, 11, 18, 19, 6, 0, 19],
[19, 15, 12, 9, 18, 3, 18, 1],
[15, 3, 17, 15, 3, 16, 0, 6],
[ 1, 4, 12, 10, 4, 10, 10, 4],
[18, 13, 3, 16, 1, 2, 15, 17]]])
====>>> noise
tensor([[ 8, 16, 16, 4, 17],
[ 0, 13, 4, 19, 17],
[14, 17, 1, 9, 4]])
====>>> ids_shuffle
tensor([[3, 0, 1, 2, 4],
[0, 2, 1, 4, 3],
[2, 4, 3, 0, 1]])
====>>> ids_keep.unsqueeze(-1)
tensor([[[3],
[0]],
[[0],
[2]],
[[2],
[4]]])
====>>> ids_keep
tensor([[3, 0],
[0, 2],
[2, 4]])
====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0],
[2, 2, 2, 2, 2, 2, 2, 2]],
[[2, 2, 2, 2, 2, 2, 2, 2],
[4, 4, 4, 4, 4, 4, 4, 4]]])
====>>> x_masked
tensor([[[11, 19, 19, 11, 10, 11, 7, 11],
[13, 6, 7, 15, 1, 9, 7, 17]],
[[ 4, 12, 5, 7, 12, 15, 14, 6],
[12, 17, 12, 11, 2, 9, 8, 1]],
[[15, 3, 17, 15, 3, 16, 0, 6],
[18, 13, 3, 16, 1, 2, 15, 17]]])
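To complete the picture, the kept tokens can later be put back into their original positions with ids_restore. The sketch below reuses the variables above and fills the removed positions with zero placeholders (MAE itself appends a learned mask token), so it is an illustration rather than the exact MAE code:
mask_tokens = torch.zeros(3, 5 - len_keep, D, dtype=x.dtype)   # placeholders for the removed tokens
x_ = torch.cat([x_masked, mask_tokens], dim=1)                 # [3, 5, D], still in shuffled order
x_restored = torch.gather(x_, dim=1,
                          index=ids_restore.unsqueeze(-1).repeat(1, 1, D))  # back to original order
# every kept position of x_restored matches the original x
kept = ids_keep.unsqueeze(-1).repeat(1, 1, D)
assert torch.equal(torch.gather(x_restored, 1, kept), torch.gather(x, 1, kept))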
From: https://www.cnblogs.com/yanghailin/p/18007025