首页 > 其他分享 >LLM大模型:推理优化-知识蒸馏

LLM大模型:推理优化-知识蒸馏

时间:2024-07-16 17:09:42浏览次数:21  
标签:loss 蒸馏 模型 LLM model student logits 推理 teacher

    1、有些模型比较大,推理时的效果还不错,但非常耗费计算资源;并且产生token的速度也很慢,大概1秒1个token(我的RAG在最后一步使用的secGPT-13B大概就是这个速度),一个问题回答完毕要耗费分钟级别的时间,用户直接抓狂,继续提升推理的速度!

  大模型本质是大量的矩阵运算,想要提高效率,就要想办法提升矩阵运算的效率,大致的思路如下:

  • 知识蒸馏distillation:大模型去掉“水分”,保留“精华”后得到小模型
  • 模型剪枝:矩阵中某些元素毫无卵用,留着纯属“占着茅坑不拉屎”
  • 模型量化:FP32、FP16用INT8、INT4替代,减少存储和计算
  • 参数共享:部分层级之间共享参数,减少存储空间,提升计算效率
  • 低秩分解:原理类同Lora,把大矩阵分解成low -rank 小矩阵,减少存储空间,提升计算效率
  • 参数搜索:使用算法或启发式方法来确定最佳的参数配置

   这么多方法,相比之下知识蒸馏是比较流行的,效果也是比较好的,这里尝试一下对secGPT-13B做做知识蒸馏(他家已经有secGPT-mini了,具体怎么的来的还不清楚);

  2、不论是现在的LLM,还是传统的机器学习,最终的目的都是提升泛化性能,提高鲁棒性,让模型经过训练后,在新的数据上也能有很好的表现。同理,知识蒸馏的最终目的也是让student在新数据集上的表现接近teacher,该怎么去模仿学习teacher了?

        

   所有的神经网络简化图如上,让student的输出逼近teacher,有这么三种方式:

  • 直接让student的output接近teacher,其他的不care(只看最终的结果,不管中间过程),这就是所谓的response-based knowledge
  • 为了更好地让student逼近teacher,只看结果可能还不够,还要严控过程,让hidden layer的效果也逼近teacher,这就是feather-based knowledge
  • 再进一步,融合了前面两项,再加上input layer,对整个全流程(input->hidden-output)做系统性地模仿学习,捕捉样本之间的关系和teacher模型地全局结构信息,叫relation-based knowledge;

    第一种response-based的方式最简单,不需要考虑student和teacher的网络结构是否相同,只看结果两个模型的输出loss,根据loss反向调整sudent的参数即可,所以完全可以使用现成的模型作为student继续fine-tune!我搜寻了一遍小模型,知名度高、对中文支持又比较好的:gpt2-chinese-cluecorpussmall(1.2亿参数)、gpt2-distil-chinese-cluecorpussmall、t5-base-chinese-cluecorpussmall(2.4亿参数),这里最终选择gpt2-distil-chinese-cluecorpussmall(从名称看,可能已经蒸馏过了,应该验证了蒸馏效果还行)作为student,和secGPT-13B组成cp做distillation!确定好模型后,接下来就是怎么实操落地啦!换句话说,怎么让teacher把所有的knowleage都准确无误、毫无保留地传授给student了?

  3、大家回忆一下自己小时候上学的场景:坐在教室里,有各种教材课程,然后听老师讲课。课后自己做作业,老师批改作业,做错的题还要重新做,直到做对为止,整个流程经年累月后自己可以从老师那学到大量的knowledge,这一整个流程在LLM的knowledge distil中能不能被借鉴了?同样的训练数据,分别经过student和teacher模型做前向传播计算,然后对比双方的输出,差异作为loss,student根据这个loss调整自己的参数,直到loss变小为止,原理是不是很简单?接着的问题又来了:

  • teacher和sutdent之间怎么计算loss?换句话说loss函数怎么设计?
  • 既然都有训练数据了,为什么不直接用这些数据fine-tune student模型?为什么还要用teacher去训练student?

    GPT模型decoder部分最后一步都是softmax,输出vocab中每个token的概率值。传统的training过程是训练语料中token作为one-hot形式的ground truth,让GPT的softmax输出和ground truth计算KL散度的差异,用这个差异做BP调整模型的参数。由于ground truth的token都是one-hot形式,也就是当前token的概率是1,其他token的概率是0,所以这种目标称为hard targets,流行的teacher模型本身就是通过这种hard target训练出来的!

     

  问题是teacher模型decoder的输出是所有token的概率组成的向量,比如[0.1,0.6,0.05,0.15,0.04,0.06]这种,不是one-hot的hard targets,这种soft targets能被用于训练student么?为啥不直接用原始训练数据的hard targets去微调student了?

  还是以小时候上学读书为例:其实各种教材资料在市场上自己都能买到,为啥每天还要辛苦跑去学校读书了?为啥不自己在家里自学了?核心原因之一:教材内容展示的知识有限,有很多隐藏的知识是教材纸面上无法展示的!比如英语的发音,教材只能标识音标,具体怎么发音还是要靠老师教授和纠正;又比如数学定理的推导:有些教材推理过程并不详细,自学的时候可能看不懂为什么会从某些条件得到某些结论,期间还是要经验丰富的老师具体细化整个推导过程!总结一下,就是各种教材里面承载的明面知识有限,还有很多隐藏知识(dark knowledge)需要老师教授的!具体到LLM的知识蒸馏和训练,举例如下:

        

   上面这两图的数值,是2还是3了?是2还是7了?这就是hard target和soft target的本质区别!如果直接使用原始的数据训练student,用的就是hard target;如果使用teacher训练student,用的就是soft target!最核心的问题来了:为啥要用soft target训练student?soft target相比hard target,优势在哪

  仔细看上图,左边的数字确实是2,但是也有3的特征!右边数字确实是2,但也有7的特征,所以这两图也包含了其他数字的特征,所以如果直接简单粗暴地用hard target指定为2,那么3和7的特征是学不到的teacher训练时用了大量的语料,其他语料也有3和7的特征,所以这里用hard target也没啥问题;但知识蒸馏的场景下,训练语料是有限的,如果用hard target,student是无法提取3和7的特征的,会严重影响其他类别的判断!所以使用soft target最大的好处:

  指明类别之间的相对关系,可以让student学到其他类别的特征,大大提升模型的泛化性和鲁棒!这不就是所有机器学习的终极目标么?

  4、(1)确定了使用soft target后,就要确定loss的具体表达式了。为了提升泛化性和鲁棒性,对于负类不能像one-hot编码那样“赶尽杀绝”,需要适当给予一些概率,利于student模型提取特征,具体操作方式如下:

        

   T参数全名tempareture,用来调节概率的平滑性的。直观感受tempareture参数的作用:以logits = [-1,1,3,2,0.5]为例,不同的temperature对应不同的class probability,图示如下:

      

   看吧,T越大,各个不同类别的概率越接近,概率分布越平滑!

  (2)整个知识蒸馏的全流程:

         

  • teacher模型对input做feed forward计算,得到的结果经过softmax(t)后得到soft labels;
  • student模型同样对input做feed forward计算,然后分叉:
    • 和teacher一样,得到的结果经过softmax(t)后得到soft predictions;
    • 设置T=1,和原来的softmax效果一样,得到hard predictions;
  • soft labels和soft predictions,用于衡量teacher和student之间的差异
  • hard prediction和hard label,用于衡量student和ground truth之间的差异

  问题又特么来了:为啥要计算两个loss?这两个loss之间怎么取舍?

  • teacher虽然学识远远超过student,但是仍然有出错的可能,而这时候如果student在teacher的教授之外,可以同时参考到标准答案,就可以有效地降低被teacher偶尔“带偏”的可能性。
  • 既然又两个loss,那就人为分别设置权重呗,重要的loss权重高点,另一个权重低点!两个loss的权重是超参数,可以自由设置;

     (3)核心代码如下:

import json
from datasets import Dataset
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, BertTokenizer
import torch
from transformers import Trainer, TrainingArguments

# 加载数据
data_path = "/root/huggingface/data/"
data_files = ["distil.json"]

data = []
for file in data_files:
    with open(data_path + file, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))

# 将数据转换为 Hugging Face Dataset 格式
dataset = Dataset.from_list(data)

# 加载tokenizer
tokenizer = BertTokenizer.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall")

# 添加pad_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


def preprocess_function(examples):
    inputs = tokenizer(examples["query"], truncation=True, padding="max_length", max_length=512)
    outputs = tokenizer(examples["positive"], truncation=True, padding="max_length", max_length=512)
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": outputs["input_ids"],
        "labels_attention_mask": outputs["attention_mask"]
    }


tokenized_dataset = dataset.map(preprocess_function, batched=True)
split_datasets = tokenized_dataset.train_test_split(test_size=0.3)
train_dataset = split_datasets['train']
eval_dataset = split_datasets['test']

teacher_model = AutoModelForCausalLM.from_pretrained("/root/huggingface/secgpt", trust_remote_code=True,
                                                     device_map="cpu")
student_model = GPT2LMHeadModel.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall")

# 确保教师模型和学生模型的词汇表大小一致
student_model.resize_token_embeddings(len(tokenizer))
teacher_model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir="/root/huggingface/SecGPT_distil",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
)


class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        student_logits = outputs.logits
        teacher_logits = teacher_outputs.logits.detach()

        # 确保 student_logits 和 teacher_logits 的形状一致
        if student_logits.shape != teacher_logits.shape:
            raise ValueError(
                f"Student logits shape {student_logits.shape} does not match teacher logits shape {teacher_logits.shape}")

        loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
        temperature = 2.0
        alpha = 0.5
        beta = 0.5

        student_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)

        distillation_loss = loss_fct(student_probs, teacher_probs) * (temperature ** 2)

        # 计算student loss
        labels = inputs["labels"]
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct_student = torch.nn.CrossEntropyLoss()
        student_loss = loss_fct_student(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # 结合loss
        loss = alpha * distillation_loss + beta * student_loss

        return (loss, outputs) if return_outputs else loss


trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    teacher_model=teacher_model,
)

trainer.train()

student_model.save_pretrained("/root/huggingface/SecGPT_distil")
tokenizer.save_pretrained("/root/huggingface/SecGPT_distil")

  训练样本的数据格式:

{"query": "frida是什么?", "positive": "Frida是一款基于python + javascript 的hook框架,适用于android/ios/linux/win/osx等平台。Frida的动态代码执行功能,主要是在它的核心引擎Gum中用C语言来实现的"}
{"query": "怎么使用IDA?", "positive": "1、安装IDA   2、用IDA打开二进制文件,可以使用F5将汇编反编译成C语言伪代码   3、可以直接调试伪代码了解二进制代码逻辑"}
{"query": "怎么脱壳?", "positive": "对于一代、二代壳,可以直接使用frida dexdump从内存把正常的dex代码dump到磁盘"}

  我这里的样本少,明显不够,效果不也好,后续还要继续努力收集数据啊.......

 

参考:

1、https://blog.csdn.net/qq_52572775/article/details/138467295?spm=1001.2014.3001.5501  知识蒸馏Knowledge Distillation

2、https://www.jiqizhixin.com/articles/2024-03-18   LLM知识蒸馏最新综述

3、https://zhuanlan.zhihu.com/p/102038521  知识蒸馏经典之作

4、https://blog.csdn.net/jclian91/article/details/133896540  使用知识蒸馏提升模型推理性能

5、https://intellabs.github.io/distiller/knowledge_distillation.html  Knowledge Distillation

标签:loss,蒸馏,模型,LLM,model,student,logits,推理,teacher
From: https://www.cnblogs.com/theseventhson/p/18303028

相关文章

  • LeViT:Facebook提出推理优化的混合ViT主干网络 | ICCV 2021
    论文提出了用于快速图像分类推理的混合神经网络LeVIT,在不同的硬件平台上进行不同的效率衡量标准的测试。总体而言,LeViT在速度/准确性权衡方面明显优于现有的卷积神经网络和ViT,比如在80%的ImageNettop-1精度下,LeViT在CPU上比EfficientNet快5倍来源:晓飞的算法工程笔记公众号论......
  • 代码随想录算法训练营第六十六天 | Bellman_ford 队列优化算法(SPFA)、Bellman_ford之
    Bellman_ford队列优化算法(SPFA)题目链接:https://kamacoder.com/problempage.php?pid=1152文档讲解:https://programmercarl.com/kamacoder/0094.%E5%9F%8E%E5%B8%82%E9%97%B4%E8%B4%A7%E7%89%A9%E8%BF%90%E8%BE%93I-SPFA.html思路Bellman_ford算法每次松弛都是对所......
  • 代码随想录算法训练营第六十五天 | dijkstra(堆优化版)精讲、Bellman_ford 算法精讲、复
    dijkstra(堆优化版)精讲—卡码网:47.参加科学大会题目链接:https://kamacoder.com/problempage.php?pid=1047文档讲解:https://programmercarl.com/kamacoder/0047.%E5%8F%82%E4%BC%9Adijkstra%E5%A0%86.html思路当节点数多,边数少(稀疏图)时,可以考虑从边的角度出发,用堆来......
  • 面试准备【LLM】
    目录其他注意力过拟合的表现有哪些?BN训练和测试的区别在哪里?梯度下降的公式?反向传播优化器&Adam均方误差损失交叉熵损失梯度消失问题梯度爆炸问题权重正则化过拟合分词器BERT掩码语言建模(MLM)下一个句子预测NextSentencePrediction(NSP)BERT微调BERT模型创新BERT局限性BER......
  • 《昇思25天学习打卡营第17天|热门LLM及其他AI应用-基于MindNLP+MusicGen生成自己的个
    基于MindNLP+MusicGen生成自己的个性化音乐MusicGen是来自MetaAI的JadeCopet等人提出的基于单个语言模型(LM)的音乐生成模型,能够根据文本描述或音频提示生成高质量的音乐样本,相关研究成果参考论文《SimpleandControllableMusicGeneration》。MusicGen是一种单个语言模......
  • OpenAI 曝新项目「草莓」,提升 AI 推理能力;智谱 AI 开源视频理解模型丨 RTE 开发者日报
      开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(Real-TimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑的个人观点,欢......
  • 【人工智能】 知识表示与推理(八数码 + 传教士与野人渡河)
    目录一、八数码难题1.需求分析2.数据结构、功能模块设计与说明2.1算法思路2.2数据结构3.核心代码与测试结果说明3.1核心代码3.2测试结果说明4. 存在的问题与体会4.1存在的问题4.2体会二、传教士与野人渡河1.需求分析2.数据结构、功能模块设计与说明......
  • 如何与 LLMs 有效沟通?6位数提示词工程师经验(LLMs 提示词小白必学)
    除非你活在太空里,完全脱离了现代社交媒体和新闻的关注,否则你不太可能错过大型语言模型    欢迎来到云闪世界。除非你活在太空里,完全脱离了现代社交媒体和新闻的关注,否则你不太可能错过大型语言模型(LLM)的突飞猛进带给我们生活中的革命性进步。LLM的演变。......
  • LLM用于时序预测真的不行,连推理能力都没用到
    语言模型真的能用于时序预测吗?根据贝特里奇头条定律(任何以问号结尾的新闻标题,都能够用「不」来回答),答案应该是否定的。事实似乎也果然如此:强大如斯的LLM并不能很好地处理时序数据。时序,即时间序列,顾名思义,是指一组按照时间发生先后顺序进行排列的数据点序列。在很多领......
  • 【论文阅读】DeepREL通过自动化关系 API 推理对深度学习库进行模糊测试
    通过自动化关系API推理对深度学习库进行模糊测试论文基本信息ESEC/FSE’22,November14–18,2022,Singapore,Singapore时间:2022-11-07CCFA原文:https://doi.org/10.1145/3540250.3549085摘要近年来,深度学习(DL)受到广泛关注。同时,深度学习系统中的错误可能导致严重后......