首页 > 其他分享 >目标检测-锚框

目标检测-锚框

时间:2022-08-28 20:24:12浏览次数:50  
标签:tensor shift 锚框 torch 目标 检测 device height

一.锚框个数计算以及锚框高宽计算

image

二.代码实现

def multibox_prior(data, sizes, ratios):
    """生成以每个像素为中心具有不同形状的锚框"""
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = (num_sizes + num_ratios - 1)
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # 生成锚框的所有中心点
    # 这里以0 1 2为例,中心点是0.5和1.5所以要加上0.5, /in_height是为了进行归一化
    center_h = (torch.arange(in_height, device=device) + 0.5) / in_height
    center_w = (torch.arange(in_width, device=device) + 0.5) / in_width
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)#此时shift_y和shift_x一对一地形成了所有中心点下标

    # 生成“boxes_per_pixel”个高和宽,
    # 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)
    w = torch.cat((sizes[0] * torch.sqrt(in_height * ratio_tensor[:] / in_width),
                     size_tensor[1:] * torch.sqrt(in_height * ratio_tensor[0] / in_width)))
    h = torch.cat((sizes[0] * torch.sqrt(in_width / ratio_tensor[:] / in_height), 
                     size_tensor[1:] * torch.sqrt(in_width / ratio_tensor[0] / in_height)))
    # 除以2来获得半高和半宽
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2

    # 每个中心点都将有“boxes_per_pixel”个锚框,
    # 所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                dim=1).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0) # batch(=1) * 锚框个数 * 4(左上和右下下标)

标签:tensor,shift,锚框,torch,目标,检测,device,height
From: https://www.cnblogs.com/sxq-blog/p/16633536.html

相关文章