非极大值抑制(Non-Maximum Suppression,NMS)是一种常用于目标检测和计算机视觉任务的技术,用于从重叠的检测框中选择最佳的候选框。以下是使用 PyTorch 实现标准的 NMS 算法的示例代码:
import torch
def nms(boxes, scores, iou_threshold):
sorted_indices = scores.argsort(descending=True)
selected_indices = []
while sorted_indices.numel() > 0:
current_index = sorted_indices[0]
selected_indices.append(current_index.item())
if sorted_indices.numel() == 1:
break
current_box = boxes[current_index]
other_boxes = boxes[sorted_indices[1:]]
ious = calculate_iou(current_box, other_boxes)
valid_indices = (ious <= iou_threshold).nonzero().squeeze()
if valid_indices.numel() == 0:
break
sorted_indices = sorted_indices[valid_indices + 1]
return selected_indices
def calculate_iou(box, boxes):
x1 = torch.max(box[0], boxes[:, 0])
y1 = torch.max(box[1], boxes[:, 1])
x2 = torch.min(box[2], boxes[:, 2])
y2 = torch.min(box[3], boxes[:, 3])
intersection_area = torch.clamp(x2 - x1 + 1, min=0) * torch.clamp(y2 - y1 + 1, min=0)
box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
boxes_area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
iou = intersection_area / (box_area + boxes_area - intersection_area)
return iou
# 示例数据:框坐标和置信度得分
boxes = torch.tensor([[100, 100, 200, 200], [120, 120, 220, 220], [150, 150, 250, 250]])
scores = torch.tensor([0.9, 0.8, 0.7])
# NMS 参数
iou_threshold = 0.5
# 执行 NMS 算法
selected_indices = nms(boxes, scores, iou_threshold)
print("选择的索引:", selected_indices)
在此示例中,我们首先定义了 nms
函数来执行 NMS 算法。然后,我们实现了一个简单的 calculate_iou
函数来计算两个框的交并比(IoU)。最后,我们使用示例数据 boxes
和 scores
运行 NMS 算法,并打印出选定的索引。
标签:常用,nms,代码段,current,boxes,indices,sorted,NMS From: https://www.cnblogs.com/chentiao/p/17656004.html