L1, L2, and Smooth L1 Loss

Posted: 2022-11-02 18:00:47

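import torch


def smooth_l1_loss_modify(predict, target, mask, sigma=3):
    # predict: bx2x96x128
    # target : bx2x96x128
    # mask   : bx2x96x128 bool tensor, assumed identical across the 2 channels
    # Number of supervised positions = masked elements / number of channels.
    num_object = mask.sum().item() / mask.size(1)
    if num_object == 0:  # avoid division by zero when nothing is masked
        num_object = 1
    sigma2 = sigma * sigma
    diff = predict[mask] - target[mask]
    diff_abs = diff.abs()
    # Quadratic branch for |diff| < 1/sigma^2, linear branch elsewhere.
    near = (diff_abs < 1 / sigma2).float()
    far = 1 - near
    loss = near * 0.5 * sigma2 * torch.pow(diff, 2) + far * (diff_abs - 0.5 / sigma2)
    return loss.sum() / num_object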
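The hand-written `near`/`far` branch in `smooth_l1_loss_modify` is the sigma-parameterized Smooth L1. Setting PyTorch's `beta` argument to 1/sigma^2 should make the built-in `torch.nn.functional.smooth_l1_loss` produce the same per-element values: its quadratic term 0.5 * x^2 / beta becomes 0.5 * sigma^2 * x^2, and its linear term |x| - 0.5 * beta becomes |x| - 0.5 / sigma^2. The sketch below checks this on random toy tensors; the helper name `smooth_l1_piecewise` and the toy shapes are illustrative assumptions, and the `beta` argument needs PyTorch 1.7 or later.

```python
import torch
import torch.nn.functional as F


def smooth_l1_piecewise(diff, sigma):
    """Hypothetical helper: sigma-parameterized Smooth L1, element-wise."""
    sigma2 = sigma * sigma
    diff_abs = diff.abs()
    near = (diff_abs < 1.0 / sigma2).float()  # quadratic zone: |x| < 1/sigma^2
    far = 1.0 - near                          # linear zone elsewhere
    return near * 0.5 * sigma2 * diff.pow(2) + far * (diff_abs - 0.5 / sigma2)


torch.manual_seed(0)
predict = torch.randn(2, 2, 4, 5)  # toy stand-in for b x 2 x 96 x 128
target = torch.randn(2, 2, 4, 5)
# Same boolean mask repeated over both channels.
mask = (torch.rand(2, 1, 4, 5) > 0.5).repeat(1, 2, 1, 1)

sigma = 3.0
diff = predict[mask] - target[mask]
manual = smooth_l1_piecewise(diff, sigma).sum()
builtin = F.smooth_l1_loss(predict[mask], target[mask],
                           reduction="sum", beta=1.0 / sigma ** 2)
print(torch.allclose(manual, builtin))
```

Both sides are plain sums, so dividing each by the same `num_object` afterwards keeps them equal; the built-in call can stand in for the manual `near`/`far` arithmetic.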
    
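def l2_loss_modify(predict, target, mask):
    # predict: bx2x96x128
    # target : bx2x96x128
    # mask   : bx2x96x128 bool tensor, assumed identical across the 2 channels
    num_object = mask.sum().item() / mask.size(1)
    if num_object == 0:  # avoid division by zero when nothing is masked
        num_object = 1
    masked_predict = predict[mask]
    masked_target = target[mask]
    # Sum of squared errors over masked elements, normalized per position.
    return torch.pow(masked_predict - masked_target, 2).sum() / num_object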
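A small hand-checkable example makes the normalization in `l2_loss_modify` concrete: `mask.sum()` counts masked elements across both channels, and dividing by `mask.size(1)` (the channel count, 2) turns that into the number of masked spatial positions, assuming the mask is the same in every channel. The 1x2x1x3 toy tensors below are made up for illustration, and `F.mse_loss` with `reduction="sum"` is used only as a cross-check of the numerator.

```python
import torch
import torch.nn.functional as F

# Toy tensors: batch 1, 2 channels, a 1x3 spatial grid (shape 1x2x1x3).
predict = torch.tensor([[[[1.0, 2.0, 3.0]],
                         [[0.0, 1.0, 2.0]]]])
target = torch.zeros_like(predict)
# Boolean mask, identical in both channels: only positions 0 and 2 count.
mask = torch.tensor([[[[True, False, True]],
                      [[True, False, True]]]])

num_object = mask.sum().item() / mask.size(1)  # 4 masked elements / 2 channels
if num_object == 0:
    num_object = 1

l2_manual = torch.pow(predict[mask] - target[mask], 2).sum() / num_object
l2_builtin = F.mse_loss(predict[mask], target[mask], reduction="sum") / num_object
print(num_object, l2_manual.item(), l2_builtin.item())  # 2.0 7.0 7.0
```

The masked residuals are 1 and 3 in channel 0 and 0 and 2 in channel 1, so the squared sum is 1 + 9 + 0 + 4 = 14, and dividing by 2 positions gives 7.0.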

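def l1_loss_modify(predict, target, mask):
    # predict: bx2x96x128
    # target : bx2x96x128
    # mask   : bx2x96x128 bool tensor, assumed identical across the 2 channels
    num_object = mask.sum().item() / mask.size(1)
    if num_object == 0:  # avoid division by zero when nothing is masked
        num_object = 1
    masked_predict = predict[mask]
    masked_target = target[mask]
    # Sum of absolute errors over masked elements, normalized per position.
    return torch.abs(masked_predict - masked_target).sum() / num_object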
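The three losses differ mainly in how their gradients scale with the residual x. L1 has a constant-magnitude gradient sign(x). L2 has gradient 2x, so large outliers dominate. Smooth L1 with parameter sigma behaves like a scaled L2 (gradient sigma^2 * x) inside |x| < 1/sigma^2 and like L1 (gradient sign(x)) outside. The minimal sketch below checks this with autograd on three hand-picked residuals. It swaps in `F.smooth_l1_loss` with `beta = 1/sigma^2` (PyTorch 1.7+) for the Smooth L1 branch, and `grad_of` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

# Residuals: inside the quadratic zone (1/9 ~ 0.111), moderate, and an outlier.
residual = torch.tensor([0.05, 0.5, 5.0], requires_grad=True)
zeros = torch.zeros(3)
sigma = 3.0


def grad_of(loss_fn):
    """Hypothetical helper: per-element gradient of a summed loss."""
    (g,) = torch.autograd.grad(loss_fn(residual, zeros).sum(), residual)
    return g


g_l1 = grad_of(lambda p, t: torch.abs(p - t))
g_l2 = grad_of(lambda p, t: torch.pow(p - t, 2))
g_smooth = grad_of(lambda p, t: F.smooth_l1_loss(p, t, reduction="none",
                                                  beta=1.0 / sigma ** 2))
print("L1       :", g_l1.tolist())
print("L2       :", g_l2.tolist())
print("Smooth L1:", g_smooth.tolist())
```

Up to float32 rounding, L1 gives [1, 1, 1], L2 gives [0.1, 1.0, 10.0], and Smooth L1 gives [0.45, 1, 1]: the outlier's gradient is capped at 1 as in L1, while the small residual gets a gradient proportional to its size as in L2.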

 

From: https://www.cnblogs.com/xiaoruirui/p/16851858.html
