Distillation process:
import torch
import torch.nn.functional as F

for epoch in range(epochs):
    student_model.train()
    for batch, (data, target) in enumerate(train_loader):
        student_logits = student_model(data)
        # The teacher is not updated: run it without tracking gradients.
        with torch.no_grad():
            teacher_logits = teacher_model(data)
        # Loss between the student and the ground-truth labels.
        loss_cri = F.cross_entropy(student_logits, target)
        # Loss between the student and the teacher (soft targets).
        loss_kd = soft_cross_entropy(student_logits / temperature,
                                     teacher_logits / temperature)
        ## Alternative KD loss via KL divergence (differs from the
        ## cross-entropy form only by the teacher's entropy, which is
        ## constant with respect to the student):
        # p_s = F.log_softmax(student_logits / temperature, dim=1)
        # p_t = F.softmax(teacher_logits / temperature, dim=1)
        # loss_kd = F.kl_div(p_s, p_t, reduction='sum') * (temperature ** 2) / student_logits.shape[0]
        # Total loss: weighted sum of the hard-label and distillation terms.
        loss = alpha * loss_cri + beta * loss_kd
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
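The soft_cross_entropy helper is not defined in the snippet. A minimal sketch, consistent with the definition in the linked TinyBERT task_distill.py (temperature scaling is applied by the caller, as in the loop above):

import torch.nn.functional as F

def soft_cross_entropy(predicts, targets):
    # Cross-entropy between the student's softened distribution and the
    # teacher's softened distribution. log_softmax on the student side keeps
    # the computation numerically stable; gradients flow only through
    # `predicts`, since `targets` come from a no_grad forward pass.
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).mean()

Note that the commented-out KL-divergence variant multiplies by temperature ** 2; this follows the recommendation of Hinton et al. to keep the soft-target gradients on the same scale as the hard-label loss when the temperature is raised.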
GitHub link: https://github.com/huawei-noah/Pretrained-Language-Model/blob/master/TinyBERT/task_distill.py
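For reference, the loop assumes the models, data loader, optimizer, and hyperparameters already exist. A hypothetical toy setup that can run before the loop; all sizes and values below are illustrative, not from the original post:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy models: a larger frozen teacher distilled into a smaller student.
teacher_model = nn.Sequential(nn.Linear(784, 1200), nn.ReLU(), nn.Linear(1200, 10))
student_model = nn.Sequential(nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 10))
teacher_model.eval()  # the teacher is pretrained and never trained here

# Random stand-in data: 256 samples of 784 features, 10 classes.
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=32, shuffle=True)
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9)

epochs = 1
temperature = 4.0       # typical KD temperatures fall roughly in the 1-10 range
alpha, beta = 0.5, 0.5  # weights for the hard-label and KD terms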