首页 > 其他分享 >lightweight openpose和hrnet的loss的区别

lightweight openpose和hrnet的loss的区别

时间:2023-02-05 17:11:06浏览次数:34  
标签:__ loss gt pred hrnet l1 openpose size

lightweight openpose的loss都用的是平方损失。

def l2_loss(input, target, mask, batch_size):
loss = (input - target) * mask
loss = (loss * loss) / 2 / batch_size

return loss.sum()

计算各个输出都用的是l2平方损失。

hrnet的loss都用的是关键点是平方损失而用的损失函数l1损失。并乘以权重

class HeatmapLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, pred, gt, mask):
assert pred.size() == gt.size()
loss = ((pred - gt)**2) * mask
loss = loss.mean(dim=3).mean(dim=2).mean(dim=1).mean(dim=0)
return loss


class OffsetsLoss(nn.Module):
def __init__(self):
super().__init__()

def smooth_l1_loss(self, pred, gt, beta=1. / 9):
l1_loss = torch.abs(pred - gt)
cond = l1_loss < beta
loss = torch.where(cond, 0.5*l1_loss**2/beta, l1_loss-0.5*beta)
return loss

def forward(self, pred, gt, weights):
assert pred.size() == gt.size()
num_pos = torch.nonzero(weights > 0).size()[0]
loss = self.smooth_l1_loss(pred, gt) * weights
if num_pos == 0:
num_pos = 1.
loss = loss.sum() / num_pos
return loss

标签:__,loss,gt,pred,hrnet,l1,openpose,size
From: https://www.cnblogs.com/hahaah/p/17093617.html

相关文章