1. torch中的索引矩阵
torch中有很多场景下都会生成索引矩阵,索引矩阵的shape和tensor的shape是相同的
a_tensor, a_index = torch.topk(a, dim=1)
# or
a_tensor, a_index = torch.sort(a, dim=1)
2. 通过索引矩阵获取tensor
通过index在指定维度提取tensor
depth_sample = torch.take_along_dim(depth_all, depth_index, dim=1)
3. 通过mask矩阵获取tensor
这里就是生成和tensor一样shape的zeros矩阵, 然后通过scatter和index和维度将index部分填充为1, 最后bool化, 这样就拿到了mask矩阵, 这样我们就可以用a[mask]直接取tensor了
mask = torch.zeros((batch_size_with_cams, Height, Width, 112)).to(mono_depth.device)
mask = mask.scatter(3, depth_index, 1).bool()
4. mask和take_along_dim的异同
这里做了几个操作
- 通过topk拿到了7个最大值, topk是会自己做排序的, 所以索引不是顺序的, 他根据tensor的大小做了相同的位置变化
- reference_depth是[0, 112]的顺序数字, 用take_along_dim按照depth_index取值, 再用sort重新排序拿到depth_sample_index
- 利用depth_sample_index对mono_depth_sample重新排序
- 最后xx和xxx是完全相同的
- 不过如果要排序的dim如果不是最后一维, 则xx和xxx不相同
mask = torch.zeros((batch_size_with_cams, Height, Width, 112)).to(mono_depth.device)
mask = mask.scatter(3, depth_index, 1).bool()
depth_all = self.reference_depth[None, None, None, :].repeat(Batch, Height, Width, 1)
depth_sample = torch.take_along_dim(depth_all, depth_index, dim=3)
depth_sample, depth_sample_index = torch.sort(depth_sample, dim=3)
xx = mono_depth.clone()[mask]
xxx = torch.take_along_dim(mono_depth_sample, depth_sample_index, dim=3).reshape(-1)
if (xx == xxx).all():
print("*"*10)
标签:index,tensor,dim,mask,torch,矩阵,depth
From: https://www.cnblogs.com/qufang/p/16886711.html