class GIoULoss(nn.Module):
    """Generalized IoU (GIoU) loss averaged over the boxes in ``A``.

    Both inputs carry box coordinates along dimension 1, where indices
    0..3 are read as x-min, y-min, x-max and y-max. Widths and heights
    are computed as ``max - min + 1``. Any layout with the coordinates
    on dimension 1 is accepted, e.g. ``(N, 4)``, ``(N, 4, M)`` or
    ``(N, 4, H, W)``.
    """

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Return the mean ``1 - GIoU`` between the boxes of ``A`` and ``B``.

        Args:
            A: Boxes with coordinates on dimension 1.
            B: Boxes laid out like ``A`` (or broadcastable to it).

        Returns:
            A scalar tensor: the sum of per-box losses divided by the
            number of boxes in ``A``. The result is NaN when ``A``
            contains no boxes (0 / 0).
        """
        ax, ay, ar, ab = A[:, 0], A[:, 1], A[:, 2], A[:, 3]
        bx, by, br, bb = B[:, 0], B[:, 1], B[:, 2], B[:, 3]

        # One element per box once the coordinate axis is indexed away.
        # Equals A.size(0) * A.size(2) for (N, 4, M) input, but also works
        # for (N, 4) and for inputs with more trailing dimensions.
        num_bbox = ax.numel()

        # Intersection rectangle; clamping zeroes out disjoint boxes.
        inter_w = (torch.min(ar, br) - torch.max(ax, bx) + 1).clamp(min=0)
        inter_h = (torch.min(ab, bb) - torch.max(ay, by) + 1).clamp(min=0)
        cross = inter_w * inter_h

        area_a = (ar - ax + 1) * (ab - ay + 1)
        area_b = (br - bx + 1) * (bb - by + 1)
        union = area_a + area_b - cross
        iou = cross / union

        # Smallest axis-aligned rectangle enclosing both boxes.
        hull_w = torch.max(ar, br) - torch.min(ax, bx) + 1
        hull_h = torch.max(ab, bb) - torch.min(ay, by) + 1
        c = hull_w * hull_h

        # GIoU penalises the part of the enclosing box not covered by
        # the union, so disjoint boxes still produce a useful gradient.
        giou = iou - (c - union) / c
        return (1 - giou).sum() / num_bbox
# Tags: giou, ab, torch, cross, ar, ay, ax
# Source: https://www.cnblogs.com/xiaoruirui/p/16851859.html