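一、锚框个数计算以及锚框高宽计算
设输入特征图高为 $H$、宽为 $W$,缩放比有 $n$ 个($s_1,\dots,s_n$),宽高比有 $m$ 个($r_1,\dots,r_m$)。为了控制数量,只保留包含 $s_1$ 或 $r_1$ 的组合,因此每个像素生成 $n+m-1$ 个锚框,整张图共 $H \cdot W \cdot (n+m-1)$ 个锚框。
对缩放比 $s$、宽高比 $r$ 的锚框,归一化后的宽为 $s\sqrt{H r / W}$,归一化后的高为 $s\sqrt{W / (r H)}$;换算成像素后宽为 $s\sqrt{HWr}$、高为 $s\sqrt{HW/r}$,即宽高比为 $r$、面积为 $s^2 HW$。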
二.代码实现
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes of several shapes centred on every pixel.

    With ``n = len(sizes)`` scales and ``m = len(ratios)`` aspect ratios,
    each pixel receives ``n + m - 1`` anchors: first ``sizes[0]`` paired
    with every ratio, then every remaining size paired with ``ratios[0]``.

    For scale ``s`` and ratio ``r`` on an ``H x W`` input, the anchor's
    width is ``s * sqrt(H * r / W)`` (as a fraction of the input width) and
    its height is ``s * sqrt(W / (r * H))`` (as a fraction of the input
    height).  In pixel units that is a box of width ``s * sqrt(H * W * r)``
    and height ``s * sqrt(H * W / r)``, i.e. aspect ratio ``r`` and area
    ``s**2 * H * W``.

    Args:
        data: Tensor whose last two dimensions are (height, width).  Only
            its shape and device are read.
        sizes: Non-empty sequence of scales.
        ratios: Non-empty sequence of width/height aspect ratios.

    Returns:
        Tensor of shape ``(1, H * W * (n + m - 1), 4)`` on ``data.device``.
        Each row is ``(xmin, ymin, xmax, ymax)`` with x divided by the input
        width and y by the input height.  Pixels are visited row by row, and
        the anchors of one pixel are contiguous.  Boxes are not clipped, so
        coordinates may fall outside ``[0, 1]``.
    """
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = num_sizes + num_ratios - 1
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # Pixel centres.  Pixel i spans [i, i + 1), so its centre is i + 0.5;
    # dividing by the extent normalises the coordinate into (0, 1).
    center_h = (torch.arange(in_height, device=device) + 0.5) / in_height
    center_w = (torch.arange(in_width, device=device) + 0.5) / in_width
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    # After flattening, shift_y[k] and shift_x[k] together give the k-th
    # centre in row-major order.
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # Widths and heights of the ``boxes_per_pixel`` anchor shapes:
    # (sizes[0], every ratio) followed by (sizes[1:], ratios[0]).
    w = torch.cat((
        sizes[0] * torch.sqrt(in_height * ratio_tensor / in_width),
        size_tensor[1:] * torch.sqrt(in_height * ratio_tensor[0] / in_width),
    ))
    h = torch.cat((
        sizes[0] * torch.sqrt(in_width / ratio_tensor / in_height),
        size_tensor[1:] * torch.sqrt(in_width / ratio_tensor[0] / in_height),
    ))

    # Offsets from a centre to the (xmin, ymin, xmax, ymax) corners.  Halve
    # the small (boxes_per_pixel, 4) table first, then tile it once per
    # pixel, instead of dividing the full tiled tensor.
    half_extents = torch.stack((-w, -h, w, h)).T / 2
    anchor_manipulations = half_extents.repeat(in_height * in_width, 1)

    # Every centre owns ``boxes_per_pixel`` consecutive anchors, so repeat
    # each centre row that many times to line up with the tiled offsets.
    out_grid = torch.stack(
        [shift_x, shift_y, shift_x, shift_y], dim=1
    ).repeat_interleave(boxes_per_pixel, dim=0)

    output = out_grid + anchor_manipulations
    # Add a leading batch dimension of 1: (1, num_anchors, 4).
    return output.unsqueeze(0)
标签:tensor,shift,锚框,torch,目标,检测,device,height
From: https://www.cnblogs.com/sxq-blog/p/16633536.html