首页 > 其他分享 >torch中通过索引矩阵获取tensor

torch中通过索引矩阵获取tensor

时间:2022-11-13 19:46:41浏览次数:55  
标签:index tensor dim mask torch 矩阵 depth

1. torch中的索引矩阵

torch中有很多场景下都会生成索引矩阵,索引矩阵的shape和tensor的shape是相同的

a_tensor, a_index = torch.topk(a, dim=1)
# or
a_tensor, a_index = torch.sort(a, dim=1)

2. 通过索引矩阵获取tensor

通过index在指定维度提取tensor

depth_sample = torch.take_along_dim(depth_all, depth_index, dim=1)

3. 通过mask矩阵获取tensor

这里就是生成和tensor一样shape的zeros矩阵, 然后通过scatter和index和维度将index部分填充为1, 最后bool化, 这样就拿到了mask矩阵, 这样我们就可以用a[mask]直接取tensor了

mask = torch.zeros((batch_size_with_cams, Height, Width, 112)).to(mono_depth.device)
mask = mask.scatter(3, depth_index, 1).bool()

4. mask和take_along_dim的异同

这里做了几个操作

  1. 通过topk拿到了7个最大值, topk是会自己做排序的, 所以索引不是顺序的, 他根据tensor的大小做了相同的位置变化
  2. reference_depth是[0, 112]的顺序数字, 用take_along_dim按照depth_index取值, 再用sort重新排序拿到depth_sample_index
  3. 利用depth_sample_index对mono_depth_sample重新排序
  4. 最后xx和xxx是完全相同的
  5. 不过如果要排序的dim如果不是最后一维, 则xx和xxx不相同
mask = torch.zeros((batch_size_with_cams, Height, Width, 112)).to(mono_depth.device)
mask = mask.scatter(3, depth_index, 1).bool()
depth_all = self.reference_depth[None, None, None, :].repeat(Batch, Height, Width, 1)
depth_sample = torch.take_along_dim(depth_all, depth_index, dim=3)
depth_sample, depth_sample_index = torch.sort(depth_sample, dim=3)
xx = mono_depth.clone()[mask]
xxx = torch.take_along_dim(mono_depth_sample, depth_sample_index, dim=3).reshape(-1)
if (xx == xxx).all():
   print("*"*10)

标签:index,tensor,dim,mask,torch,矩阵,depth
From: https://www.cnblogs.com/qufang/p/16886711.html

相关文章