import torch
# Demo of Tensor.gather along dim=1: select one column per row of `a`.
# NOTE: torch.randint is not seeded here, so the values in `a` (and thus the
# printed numbers) differ on every run.
a = torch.randint(0, 100, (6, 3))  # [6, 3]

# One column index per row. gather() requires an int64 index tensor with the
# same number of dims as `a`, so build it directly as long and make it a
# column vector of shape [6, 1].
b = torch.tensor([0, 1, 1, 2, 0, 2], dtype=torch.long).unsqueeze(1)  # [6, 1]

# For dim=1, gather computes b2[i][j] = a[i][b[i][j]]; the result has b's
# shape [6, 1]. `b` is already 2-D, so no further reshaping is needed.
b2 = a.gather(1, b)

print(a)
print(a.shape)
print(b)
print(b.shape)
print(b2)
输出(`a` 由未设置随机种子的 `torch.randint` 生成,每次运行的数值都会不同):
tensor([[ 9, 10, 79],
[98, 43, 2],
[94, 82, 24],
[93, 72, 3],
[30, 29, 86],
[94, 25, 4]])
torch.Size([6, 3])
tensor([[0],
[1],
[1],
[2],
[0],
[2]])
torch.Size([6, 1])
tensor([[ 9],
[43],
[82],
[ 3],
[30],
[ 4]])
PyTorch SSD 中 gather 的用法(计算 hard negative mining 所用的分类损失):
# Per-sample classification loss over the whole batch, for hard negative mining.
# Example shapes: conf_data [3, 8732, 21] -> batch_conf [3*8732, 21] = [26196, 21]
batch_conf = conf_data.view(-1, self.num_classes)  # [26196, 21]

# Target class index of every row as a column vector. gather() requires an
# int64 index tensor with the same number of dims as batch_conf.
target_idx = conf_t.view(-1, 1)  # [26196, 1]

# For dim=1: target_logit[i][0] = batch_conf[i][target_idx[i][0]],
# i.e. the score of the target class for each row.
target_logit = batch_conf.gather(1, target_idx)  # [26196, 1]

# NOTE(review): assumes log_sum_exp returns the row-wise log(sum(exp(x)))
# with shape [N, 1] (defined elsewhere) — confirm.
lse = log_sum_exp(batch_conf)  # [26196, 1]

# Under that assumption, logsumexp(x) - x[target] == -log_softmax(x)[target],
# i.e. the unreduced cross entropy:
#   F.cross_entropy(batch_conf, conf_t.view(-1), reduction="none").view(-1, 1)
# (the default reduction="mean" would give a scalar instead).
# See https://zhuanlan.zhihu.com/p/153535799
# Each term is computed once and reused instead of being recomputed.
loss_c = lse - target_logit  # [26196, 1]
标签:torch,b2,gather,batch,26196,pytorch,conf,view
From: https://www.cnblogs.com/yanghailin/p/17255414.html