首页 > 其他分享 >pytorch ssd 代码疑惑, flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)

pytorch ssd 代码疑惑, flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)

时间:2023-03-13 11:23:54浏览次数:62  
标签:unsqueeze top flt self num output 100.0000

https://github.com/amdegroot/ssd.pytorch/blob/5b0b77faa955c1917b0c710d770739ba8fbff9b7/layers/functions/detection.py#L58

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch,num_priors*4]
            conf_data: (tensor) Shape: Conf preds from conf layers
                Shape: [batch*num_priors,num_classes]
            prior_data: (tensor) Prior boxes and variances from priorbox layers
                Shape: [1,num_priors,4]
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            # For each class, perform nms
            conf_scores = conf_preds[i].clone()

            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)
        flt = output.contiguous().view(num, -1, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)
        flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
        return output

这段代码疑惑:

  flt = output.contiguous().view(num, -1, 5)
  _, idx = flt[:, :, 0].sort(1, descending=True)
  _, rank = idx.sort(1)
  flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
  return output

自己写了测试函数测试这段代码,发现对output没有任何影响啊

import torch

batch_size = 1
num_classes = 2
top_k = 4
#output[1, 2, 4, 5]
output = torch.zeros(batch_size, num_classes, top_k, 5) #[b,21,200,5]

a0 = torch.rand(4, 5)
a1 = torch.rand(4, 5)

output[0, 0, :] = a0
output[0, 1, :] = a1

print("==================== output==")
print(output)

flt = output.contiguous().view(batch_size, -1, 5)  # [b,21*200,5]
print("==================== flt==")
print(flt)
_, idx = flt[:, :, 0].sort(1, descending=True)
_, rank = idx.sort(1)

flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(-100) ##src

#flt[(rank >= top_k).unsqueeze(-1).expand_as(flt)] = -100

print("====================last flt==")
print(flt)

print("====================last output==")
print(output)
==================== output==
tensor([[[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
          [0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
          [0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
          [0.0884, 0.6303, 0.1796, 0.3239, 0.7133]],

         [[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
          [0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
          [0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
          [0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]]])
==================== flt==
tensor([[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
         [0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
         [0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
         [0.0884, 0.6303, 0.1796, 0.3239, 0.7133],
         [0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
         [0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
         [0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
         [0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]])
====================last flt==
tensor([[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
         [0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
         [0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
         [0.0884, 0.6303, 0.1796, 0.3239, 0.7133],
         [0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
         [0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
         [0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
         [0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]])
====================last output==
tensor([[[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
          [0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
          [0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
          [0.0884, 0.6303, 0.1796, 0.3239, 0.7133]],

         [[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
          [0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
          [0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
          [0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]]])

Process finished with exit code 0

看了issue,有人也发现这个问题了。
https://github.com/amdegroot/ssd.pytorch/issues/168
可能是在pyt0.4上面有效,在高版本上面就无用了吧。
评论有人给出了解决方案就是代码里面我注释的那句话,
flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(-100) ##src
改为
flt[(rank >= top_k).unsqueeze(-1).expand_as(flt)] = -100

这样是有效的,可以对output修改,输出如下:

==================== output==
tensor([[[[0.2341, 0.2941, 0.4434, 0.2481, 0.7296],
          [0.3081, 0.6865, 0.7391, 0.9371, 0.1801],
          [0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
          [0.0919, 0.7764, 0.6790, 0.7079, 0.6412]],

         [[0.5518, 0.2866, 0.6437, 0.1184, 0.8749],
          [0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
          [0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
          [0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]]])
==================== flt==
tensor([[[0.2341, 0.2941, 0.4434, 0.2481, 0.7296],
         [0.3081, 0.6865, 0.7391, 0.9371, 0.1801],
         [0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
         [0.0919, 0.7764, 0.6790, 0.7079, 0.6412],
         [0.5518, 0.2866, 0.6437, 0.1184, 0.8749],
         [0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
         [0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
         [0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]])
====================last flt==
tensor([[[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
         [-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
         [   0.9775,    0.9983,    0.1749,    0.1505,    0.1860],
         [-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
         [-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
         [   0.6722,    0.4248,    0.6839,    0.9222,    0.8995],
         [   0.6662,    0.9287,    0.3097,    0.6207,    0.5590],
         [   0.6176,    0.4586,    0.5354,    0.6958,    0.4959]]])
====================last output==
tensor([[[[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
          [-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
          [   0.9775,    0.9983,    0.1749,    0.1505,    0.1860],
          [-100.0000, -100.0000, -100.0000, -100.0000, -100.0000]],

         [[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
          [   0.6722,    0.4248,    0.6839,    0.9222,    0.8995],
          [   0.6662,    0.9287,    0.3097,    0.6207,    0.5590],
          [   0.6176,    0.4586,    0.5354,    0.6958,    0.4959]]]])

Process finished with exit code 0

标签:unsqueeze,top,flt,self,num,output,100.0000
From: https://www.cnblogs.com/yanghailin/p/17210669.html

相关文章

  • 虚假新闻检测(CALN)《Open-Topic False Information Detection on Social Networks with
    论文信息论文标题:Open-TopicFalseInformationDetectiononSocialNetworkswithContrastiveAdversarialLearning论文作者:GuanghuiMa,ChunmingHu,LingGe,Hon......
  • 网络爬虫-爬取豆瓣Top250
    一、选题的背景(10分)本次爬取的内容是豆瓣网站平均评分第一名到第二百五十名的电影名称,电影链接,电影封面图片链接,电影的概况和电影的相关信息。现在电影是人们一种很普遍的......
  • iTOP-RK3568开发板OTA升级包编译
    本节我们将编译三个版本的android镜像,V1.0.0版本、V1.0.1版本、V1.0.2版本,其中V1.0.0版本为基础版本用于烧写到rk3568开发板上,V.1.0.0升级到V1.0.1采用完全升级......
  • top详解
    第一行是任务队列信息,同uptime  命令的执行结果.其内容如下:01:06:48当前时间up1:22系统运行时间,格式为时:分1user当前登录用户数loadaverage:0.06......
  • jquery获取设置元素宽高位置height()、width()、offset()、position()、scrollTop()、
    ​​​​全栈工程师开发手册(作者:栾鹏)​​jquery系列教程2-style样式操作全解​​jquery获取设置元素宽高位置jquery的通过height()、width()、offset()、position()、s......
  • echart 解决setOption线残留
    前言:Antd+echarts我想要实现的是点击表的某一行自动生成对应的折线图,我在点击第一行生成5条线,我在点击第二行的时候,本该生成2条线,结果还是5条线;最开始我以为......
  • margin-left有效果但是margin-top没有用
     碰到一个bug,margin-left有效果但是margin-top没有用没有效果    原因是因为span、p、b等是内联函数,你需要先把这个元素设置为块级元素或块级内联函数才可......
  • Windows Docker Desktop 安装 Nacos
    前言以前都是在Linux虚拟机上的Docker安装应用,这次使用Windows10系统的DockerDesktop安装Nacos,所以用挂载文件就不是很方便了,这次采用启动参数的方式对配......
  • Windows10安装Docker Desktop
    DockerDesktop可以让我们在Windows环境下很方便的使用Docker,提供了很多便利。参考文档:https://blog.csdn.net/qq_39611230/article/details/108641842......
  • 获取元素到body顶部的距离,offsetTop和offsetParent,getBoundingClientRect
    最近在写一个可见曝光的sdk,是当元素显示在可见区域的时候才算作曝光,并上报给服务端。思路是在元素请求回来并渲染完成之后,计算元素距离document顶部的距离offset,当页面滚动......