1. 定义教师网络与学生网络
# Build the student (backbone) and teacher (EMA) networks with the same
# architecture arguments, so their state_dicts have identical keys and shapes
# for the later EMA update.
backbone_model = Net(gps=opt.gps, blocks=opt.blocks)
backbone_model = backbone_model.to(device)
ema_model = Net(gps=opt.gps, blocks=opt.blocks)
ema_model = ema_model.to(device)
# Start the teacher from the student's weights. The teacher only ever moves by
# teacher = momentum * teacher + (1 - momentum) * student, so with momentum
# close to 1 an independent random init would dominate the teacher for
# thousands of steps.
ema_model.load_state_dict(backbone_model.state_dict())
2. 学生网络参与梯度更新，教师网络不进行梯度更新
# The student is trained by backpropagation, so all of its parameters take
# gradients.
backbone_model.requires_grad_(True)
# The teacher changes only through the EMA step, never through the optimizer,
# so it is excluded from autograd entirely.
ema_model.requires_grad_(False)
3. 在不计算梯度的情况下，将输入送入教师网络进行前向传播
# Teacher forward pass on the real hazy input. No autograd graph is
# recorded, because the teacher is never updated by backprop.
with torch.no_grad():
    real_out = ema_model(real_hazy_img)
4. 以指数移动平均（EMA）的方式，用学生网络的参数更新教师网络
if opt.ema:
    # Mean-teacher EMA step:
    #     teacher <- momentum * teacher + (1 - momentum) * student
    # (momentum is opt.momentum, e.g. 0.999). Tensors returned by
    # state_dict() share storage with the module, so in-place ops on the
    # teacher's entries update ema_model directly.
    state_dict_backbone = backbone_model.state_dict()
    state_dict_ema_model = ema_model.state_dict()
    alpha = 1.0 - opt.momentum  # weight given to the student's value
    with torch.no_grad():
        for k_ema, v_ema in state_dict_ema_model.items():
            # Match by key rather than zipping two iteration orders together.
            # A missing key raises KeyError.
            v_backbone = state_dict_backbone[k_ema]
            # Explicit check instead of `assert`, which is stripped under -O.
            if v_backbone.shape != v_ema.shape:
                raise ValueError(
                    f"EMA shape mismatch for {k_ema!r}: "
                    f"teacher {tuple(v_ema.shape)} vs "
                    f"student {tuple(v_backbone.shape)}"
                )
            if v_ema.is_floating_point():
                # In place: v_ema + alpha * (v_backbone - v_ema)
                #         = momentum * v_ema + (1 - momentum) * v_backbone
                v_ema.lerp_(v_backbone, alpha)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                # cannot be meaningfully averaged, so copy them verbatim.
                v_ema.copy_(v_backbone)
5. 测试时分别对教师、学生两个模型进行测试
# Evaluate both networks in one pass over the test set.
# The *_1 metrics belong to the student (backbone_model) and the *_2 metrics
# to the teacher (ema_model), following the argument order of test().
(
    ssim_eval_1, psnr_eval_1,
    ssim_eval_2, psnr_eval_2,
) = test(backbone_model, ema_model, test_loader)
# Test-time predictions from the student and the teacher on the same input.
# The student's parameters require grad, so without no_grad this evaluation
# pass would record an autograd graph and waste memory. Wrapping is harmless
# if the caller already disabled gradients.
# NOTE(review): `input` comes from the enclosing test loop (not visible here)
# and shadows the builtin; consider renaming it there.
with torch.no_grad():
    pred = backbone_model(input)
    ema_pred = ema_model(input)
标签:opt,教师,ema,backbone,网络,学生,state,dict,model From: https://www.cnblogs.com/yyhappy/p/17560386.html