首页 > 其他分享 >Pytorch torch.meshgrid() 在目标检测中的应用

Pytorch torch.meshgrid() 在目标检测中的应用

时间:2023-01-25 21:45:17浏览次数:36  
标签:nn torch channels stride Pytorch meshgrid grid 坐标

概述

最近在学习目标检测的相关算法。在我看来目标检测要比分类、语义分割任务复杂的多,后者一般只需要为每个图像预测一个标签(分类)或者为每个像素预测一个标签(分割)。而目标检测还需要回归目标边界框同时进行分类,这使得目标检测的数据处理和训练比较复杂。

在目标检测中,一般是通过神经网络提取图像特征,得到下采样stride步幅的特征图,在特征图的每个cell上进行预测,最后将在特征层上预测的结果map回原图尺寸上。这时除了stride,还需要知道特征图的网格坐标。这时就可以用到torch.meshgrid()方法生成网格坐标。 

使用

在使用torch.meshgrid()前,简单说一下图像坐标。如下图所示,图像坐标的原点是左上角、x轴是宽、指向右;y轴是高、指向下。

需要注意的是,在pytorch中,tensor的shape一般都是(..., h, w)。要注意使用(x, y)时,坐标的对应关系,x对应的是w、y对应的是h。 我们需要得到grids网格坐标,也就是每个cell的左上角坐标。

这时我们就可以使用meshgrid方法。

feat = torch.randn(3, 4, 6) hsize, wsize = feat.shape[-2:] 
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij") 
grids = torch.stack((xv, yv), 2)

通过torch.meshgrid()生成了yv和xv,其内容如下:

# yv tensor([[0, 0, 0, 0, 0, 0], 
#            [1, 1, 1, 1, 1, 1], 
#            [2, 2, 2, 2, 2, 2],
#            [3, 3, 3, 3, 3, 3]])
#
# (hsize, wsize)

# xv tensor([[0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5]])
#
# (hsize, wsize)

 

然后通过torch.stack()将生成的xv与yv组合起来,就得到了grid网格坐标。

# tensor([[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
#         [[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1]],
#         [[0, 2], [1, 2], ...
#                          [2, 3], [3, 3], [4, 3], [5, 3]]
#        ])
#
# (hsize, wsize, 2)

其实torch.meshgrid()方法中的indexing变量在最初版本中是没有的。我们生成坐标时,都是将yv放在前面,组合时再将xv,yv stack起来,形成符合直观的网格坐标(先行再列的顺序,对应了图像的xy坐标系)。

而现在新版本的indexing变量可以通过indexing="ij" or indexing="xy"控制格式,现在仍然保留原始的方法,使用默认的ij index,大概是出于和之前代码保持一致。

 

示例

 为了直观的感受到grid网格坐标与原图的关系,我们可以将grid坐标乘以stride步幅后,映射到原图中。原图如下:

随便写一个下采样步幅stride=32的网络。

conv = nn.Sequential( 
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 2
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 4
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 8
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 16
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 32
)

提取特征并且使用上述方法得到网格坐标

img = Image.open("flower.jpg").resize((640, 416)).convert('RGB')
img = ToTensor()(img)  # (3, 416, 600)

stride = 32
output = conv(img)  # (3, 13, 20)

hsize, wsize = output.shape[-2:]
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
grid = torch.stack((xv, yv), 2)

将网格相对坐标乘以步幅得到原图绝对坐标,并通过修改img的灰度值使其可视化。

grid = grid * stride
grid = grid.view(-1, 2)
for x, y in grid:
    img[:, y, x] = 0.  # (y, x) -> (h, w)

得到最终的结果,我们看到特征图的grid cell 映射回原图的样子(黑点即为每个grid cell的坐标,为左上角)。

标签:nn,torch,channels,stride,Pytorch,meshgrid,grid,坐标
From: https://www.cnblogs.com/Brisling/p/17066790.html

相关文章

  • 3、python中的两大函数(pytorch中可用)
    1、dir():可以提供打开操作,让你看到里面有什么东西例子:查看torch下面会有哪些函数使用dir(torch),会出来函数名字,如果想细看函数里面是否还有东西可以使用dir(torch.函数名字......
  • pytorch环境安装
    1、下载anaconda,这个里面会提供很多包,所以不用下载多余的软件的,比如python2、一定要记住安装路径,后面选项都是默认,下载好之后测试一下,打开anacondaprompt界面,如果左侧括......
  • PyTorch图像分类全流程实战--迁移学习训练图像分类模型03
    教程同济子豪兄:https://space.bilibili.com/1900783斯坦福CS231N【迁移学习】中文精讲:https://www.bilibili.com/video/BV1K7411W7So斯坦福CS231N【迁移学习】官方笔记:h......
  • Pytorch:单卡多进程并行训练
    1导引我们在博客《Python:多进程并行编程与进程池》中介绍了如何使用Python的multiprocessing模块进行并行编程。不过在深度学习的项目中,我们进行单机多进程编程时一般不......
  • 深度学习的pytorch环境
    搭建主要分为下列及部分anaconda的安装和基本使用pycharm的安装和基本使用pytorch的安装第一章,anaconda不需要安装python,直接安装anaconda就行。因为里面自带一个......
  • Anaconda中Tensorflow和Pytorch环境的搭建
    Anaconda中Tensorflow和PyTorch环境的搭建title:r'Anaconda中Tensorflow和PyTorch环境的搭建'author:"hugaotuan"date:"1/13/2022"output:markdownConda......
  • mac安装pytorch环境
    已经安装好了anaconda退出base环境,condadeactivate创建的name叫mypytorch2进入mypytorch环境1查看现在环境中所有已安装的包condalist2安装python3.8condai......
  • pytorch入门
        最近因为语音识别项目,逐步深入接触到pytorch这个工具。使用Pytorch及胶水语言Python可以实现语音识别。    我从个人学习的经历谈一下入门的感受。 ......
  • pytorch 训练minist
    from__future__importprint_functionimportargparseimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasoptimfromtorchvision......
  • PyTorch图像分类全流程实战--预训练模型预测图像分类02
    主要内容今天的任务是学习预训练模型的使用,模型是Resnet18,使用的torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是常用数据集+常......