首页 > 其他分享 >深度学习(蒸馏)

深度学习(蒸馏)

时间:2024-08-03 10:49:40浏览次数:12  
标签:蒸馏 nn self torch 学习 深度 model total teacher

模型蒸馏是指通过训练一个小而简单的模型来复制和学习一个大模型的知识和性能。这种方法通常用于减少模型的计算资源需求,加速推理过程或者使模型适用于资源受限的设备上。

步骤如下:

1. 准备教师模型和学生模型:

  教师模型:一个复杂的模型,这里用的是resnet。

  学生模型:简化的卷积神经网络,较少的参数和层次结构。

2. 定义损失函数:

  交叉熵损失:使用Softmax激活函数输出的概率分布,以及温度参数来平衡模型的软化度。

3. 训练学生模型:

  在训练过程中,通过比较学生模型预测和教师模型预测之间的差异来优化模型参数。

4. 优化和调整:

  可以尝试不同的模型结构、损失函数设置和超参数调整来优化学生模型的性能和效率。

5. 评估和比较:

  使用测试数据集评估学生模型的性能,并与未经蒸馏的模型以及教师模型进行比较。

测试代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import os

# 设置是否使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义数据转换
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 定义学生模型(简单的卷积神经网络)
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(512)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
        self.dropout3 = nn.Dropout(0.2)
        self.dropout4 = nn.Dropout(0.2)

        self.fc1 = nn.Linear(512*4*4, 256)
        self.fc2 = nn.Linear(256, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.dropout1(self.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.dropout2(self.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.dropout3(self.relu(self.bn3(self.conv3(x)))))
        x = x.view(x.size(0), -1)
        x = self.dropout4(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


# 测试模型
def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = correct / total
    print(f'Accuracy on test set: {100 * accuracy:.2f}%')


def trainTecher(model,trainLoader,testloader,optimizer,criterion):

    for epoch in range(5):
        model.train()
        correct = 0
        total = 0
        for inputs, labels in trainLoader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            print(epoch,loss.item(),f" train teacher Accuracy: {(100 * correct / total):.2f}%")

        test(model,testloader)

def trainStudent(model,teacher_model, trainloader,testloader):
    
    criterion = nn.KLDivLoss()  # KL散度损失函数
    optimizer = optim.AdamW(student_model.parameters(), lr=5e-4, weight_decay=1e-3)

    for epoch in range(20):
        model.train()
        correct_stu = 0
        correct_teh = 0
        total = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            teacher_outputs = teacher_model(inputs).detach()  # 使用教师模型的输出作为软标签

            loss = criterion(nn.functional.log_softmax(outputs/5, dim=1),
                             nn.functional.softmax(teacher_outputs/5, dim=1))
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct_stu += (predicted == labels).sum().item()

            _, predicted = torch.max(teacher_outputs.data, 1)
            correct_teh += (predicted == labels).sum().item()
            print(epoch,loss.item(),f" train student Accuracy: {(100 * correct_stu / total):.2f}%",f"{(100 * correct_teh / total):.2f}%")
        
        test(model,testloader)

if __name__ == '__main__':

    # 加载数据集
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

    # 定义教师模型
    teacher_model = models.resnet18(pretrained=True)
    teacher_model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  
    teacher_model.maxpool = nn.MaxPool2d(1, 1, 0)  
    teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)
    teacher_model.to(device)

    student_model = StudentNet()
    student_model.to(device)

    total = sum([param.nelement() for param in teacher_model.parameters()])
    print("Number of parameter: %.2fM" % (total/1e6))   

    total = sum([param.nelement() for param in student_model.parameters()])
    print("Number of parameter: %.2fM" % (total/1e6))   

    if os.path.exists('resnet.pth'):    
        teacher_model.load_state_dict(torch.load('resnet.pth'))     
        teacher_model.eval()
    else:
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(teacher_model.parameters(), lr=5e-4, weight_decay=1e-3)
        trainTecher(teacher_model,trainloader,testloader,optimizer,criterion)
        torch.save(teacher_model.state_dict(), 'resnet.pth')

    trainStudent(student_model,teacher_model, trainloader,testloader)  

标签:蒸馏,nn,self,torch,学习,深度,model,total,teacher
From: https://www.cnblogs.com/tiandsp/p/18282642

相关文章

  • 一文读懂SEnet:如何让机器学习模型学会“重点观察”
    深入探讨一个在图像识别、自然语言处理等众多领域大放异彩的注意力模块——Squeeze-and-ExcitationNetworks(SEnet)。本文不仅会理论剖析SEnet的核心原理,还会手把手带你完成在TensorFlow和Pytorch这两个主流框架上的代码实现。准备好了吗?一起步入注意力机制的精妙世界。一、......
  • 功能齐全,深度适配 Home Assistant 的 CMPOWER W1 智能插排固件(附源码)
    固件特点:足够傻瓜,配网即用,无需添加/修改任何yaml文件,配网后HA中的mqttbroker会自动发现设备以及所有实体(包括计量)。支持计量功能,无需额外校准(电压,电流,功率,电量,频率,温度),基本满足日常使用。设备离线HA中自动更新状态显示设备不可用,当设备重新上线后HA中自动更新......
  • 基于深度学习的适应硬件的神经网络
    基于深度学习的适应硬件的神经网络设计旨在最大限度地利用特定硬件平台的计算和存储能力,提高模型的执行效率和性能。这些硬件包括图形处理单元(GPU)、张量处理单元(TPU)、现场可编程门阵列(FPGA)和专用集成电路(ASIC)。以下是关于适应硬件的神经网络的详细介绍:1.背景和动机硬件异构......
  • 基于深度学习的联邦学习
    基于深度学习的联邦学习(FederatedLearning,FL)是一种分布式机器学习方法,允许多个参与者(如设备或组织)在不共享原始数据的情况下共同训练模型。它通过在本地设备上训练模型,并仅共享模型更新(如梯度或参数),保护数据隐私和安全。以下是基于深度学习的联邦学习的详细介绍:1.背景和动......
  • 科大讯飞学生机平板怎么样2024 科大讯飞AI学习机T20 值得买吗
    科大讯飞AI学习机T20是一款基于24年AI技术积累的学习工具,致力于为广大学生提供更加智能化、高效的学习体验。该学习机采用了先进的AI技术,通过智能语音识别、自然语言处理等技术手段,实现了AI1对1类人辅导,能够针对不同学生的学习需求和水平,提供个性化的学习方案。不仅如此,科大讯飞A......
  • 动态规划学习笔记
    P3195求出玩具的前缀和\(S\)。设\(f_i\)表示区间\([1,i]\)的最大答案。开始应该是\(f_0=0\)。\(f_i=\max_{1\lej<i}f_j+(i+S_i-L-1-(j+S_j))^2\)。\(f_i=\max_{1\lej<i}f_j+(i+S_i-L-1)^2+(j+S_j)^2-2(i+S_i-L-1)(j+S_j)\)。设\(g_i=i+S_i,k=L+1\),那么\(f_......
  • Day16_1--JSP了解学习之EL表达式语言入门教程
    JSP(JavaServerPages)是一个用于生成动态网页的技术。EL(ExpressionLanguage)是JSP中的一种表达式语言,用于简化JSP页面中的Java代码,使其更易于书写和阅读。下面是对JSPEL表达式语言的简要介绍。1.什么是EL?EL(表达式语言)是JSP2.0引入的一种语言,它提供了一种简单的方法来访......
  • 深度解码:Java线程生命周期的神秘面纱
    在Java的编程宇宙中,线程是驱动应用程序的微小而强大的引擎。它们就像心脏的跳动,维持着程序的活力和响应性。今天,我们将深入探究线程的生命周期,理解它们从诞生到消逝的全过程,以及如何在不同状态下优雅地过渡。第二章:线程的活跃岁月执行阶段:运行与忙碌一旦被CPU选中,线程开......
  • 基于强化学习的倒立摆平衡车控制系统simulink建模与仿真
    1.算法仿真效果matlab2022a仿真结果如下(完整代码运行后无水印):      2.算法涉及理论知识概要       基于强化学习的倒立摆平衡车控制系统是一个典型的动态系统控制问题,它通过不断的学习和决策过程,使倒立摆维持在垂直平衡位置,即使受到外力干扰或系统内部噪......
  • 生成函数 学习笔记
    生成函数学习笔记有一部分没地方写的组合数学,先写这里。0.pre-learning1.上升/下降幂:\[n^{\underline{k}}=n\times(n-1)\times\cdots\times(n-k+1)\]称为\(n\)的下降幂。同理:\[n^{\overline{k}}=n\times(n-1)\times\cdots\times(n+k-1)\]称为\(......