首页 > 其他分享 >LLM大模型GPT2微调尝试

LLM大模型GPT2微调尝试

时间:2024-05-20 19:52:46浏览次数:30  
标签:logging GPT2 微调 checkpoint token LLM model data

  1、作为安全从业者,以前搞逆向、挖漏洞、干渗透全靠人工推进,缺点很明显:

  • 无法自动化,甚至也无法半自动化,效率低(后续可以开发agent解决)
  • 知识面有限,存在很多知识盲点,导致遇到部分问题无法解决(可以通过增加知识库,然后rag检索或微调大模型解决)

      尝试了一些在线的大模型(chatGPT4、copilot),挖掘 https://www.cnblogs.com/theseventhson/p/13933230.html 这个例子的栈溢出漏洞,效果还行, 能找到漏洞;离线部署的LLM中尝试了secGPT、openbuddy等,同样的代码也能准确找到漏洞所在,比如:

    

     说明基于现有大数据训练得到的模型是能够发现并解决已知问题的,这条路别人已经走通了!

     2、 训练语料

     (1)众所周知,机器学习分监督学习于非监督学习。不论哪种,训练数据都是必须的,所以先整理一下我自己的博客园技术总结:

    

      huggingface上大模型sft微调一般的数据格式为:

{
    "instruction": "Please perform static application security testing on the above code",
    "input": "static void goodG2B1()\n{\n    int data;\n    /* Initialize data */\n    data = -1;\n    if(0)\n    {\n        /* INCIDENTAL: CWE 561 Dead Code, the code below will never run */\n        printLine(\"Benign, fixed string\");\n    }\n    else\n    {\n        /* FIX: Set data to a relatively small number greater than zero */\n        data = 20;\n    }\n    {\n        size_t i;\n        int *intPointer;\n        /* POTENTIAL FLAW: if data * sizeof(int) > SIZE_MAX, overflows to a small value\n         * so that the for loop doing the initialization causes a buffer overflow */\n        intPointer = (int*)malloc(data * sizeof(int));\n        if (intPointer == NULL) {exit(-1);}\n        for (i = 0; i < (size_t)data; i++)\n        {\n            intPointer[i] = 0; /* Potentially writes beyond the boundary of intPointer */\n        }\n        printIntLine(intPointer[0]);\n        free(intPointer);\n    }\n}\n",
    "output": "this code does not violate any security encoding standards"
}

   这类数据需要标注,耗时耗力,而博客园的技术文章都是没有标注的,所以sft暂时不考虑!如果非要把博客园的文章搞成instrct格式,可以尝试把文章内容作为input,标题或主要内容作为output,instruction改成自己关注的问题就好!

     (2)我因为没时间标注,直接把文章内容挨个放入json文件;为了减少训练时内存占用,也为了GPT2的max_length=1024的限制,把文章按照1024的长度截断,分别存入json文件;一共准备了19个json文件;

       

      用于微调的样本数据:

       

      3、模型选择:huggingface上模型已多大66w,这么多模型怎么选合适的?

       (1)模型参数量:N = l * 12* d^2; l是transformer block的个数;d 是 embedding的维度;

       (2)模型总的计算量:C=6*ND; N是模型参数量,D是训练集的token总数

       (3)模型BP需要存储::16byte*N + FFN(d、样本length、batch_size等)

        由于token数量只有200w,如果选择参数大的模型,肯定欠拟合(参考scaling laws,我这点token量都达不到人家尝试数据量的下限),所以只能选择参数小的模型;这里最终选择GPT2尝试!

      4、微调的方式有很多种,这里选择截至目前最优的lora尝试:

  1 import logging
  2 import torch
  3 from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
  4 from datasets import load_dataset
  5 from peft import get_peft_model, LoraConfig, TaskType
  6 import os
  7 
  8 logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
  9 
 10
 11 def preprocess_function(examples):
 12     inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=1024)
 13     inputs["labels"] = inputs["input_ids"].copy()
 14     return inputs
 15 
 16 if __name__ == '__main__':
 17
 18     model_name = "gpt2"
 19     tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 20     model = GPT2LMHeadModel.from_pretrained(model_name)
 21     logging.info("Model loaded successfully")
 22 
 23     
 24     tokenizer.pad_token = tokenizer.eos_token
 25 
 26     
 27     training_args = TrainingArguments(
 28         output_dir="/root/huggingface/GPT2/Lora",
 29         overwrite_output_dir=True,
 30         num_train_epochs=1,
 31         per_device_train_batch_size=20,
 32         save_steps=10_000,
 33         save_total_limit=2,
 34         logging_dir="/root/huggingface/GPT2/logs",
 35         save_strategy="epoch",  
 36     )
 37 
 38     lora_config = LoraConfig(
 39         task_type=TaskType.CAUSAL_LM,
 40         r=16,
 41         lora_alpha=32,
 42         lora_dropout=0.1,
 43         target_modules=["attn.c_proj", "attn.c_attn"]
 44     )
 45 
 46     
 47     model = get_peft_model(model, lora_config)
 48     model.print_trainable_parameters()  # 打印可训练参数
 49 
 50     last_checkpoint = None
 51     checkpoint_prefix = "checkpoint"
 52 
 53     # 检查是否存在之前的检查点
 54     for i in range(19, 0, -1):
 55         checkpoint_dir = f"/root/huggingface/GPT2/{checkpoint_prefix}-{i}"
 56         if os.path.exists(checkpoint_dir):
 57             last_checkpoint = checkpoint_dir
 58             logging.info(f"last_checkpoint:{last_checkpoint}")
 59             break
 60 
 61     logging.info(f"last_checkpoint:{last_checkpoint}")
 62     # 从上一个检查点继续训练
 63     start_file_index = 1 if last_checkpoint is None else int(last_checkpoint.split('-')[-1]) + 1
 64 
 65     # 遍历每个数据文件并进行训练
 66     for i in range(start_file_index, 20):
 67         data_file = f'/root/huggingface/data/cyber_security{i}.json'
 68         dataset = load_dataset('json', data_files=data_file, split='train')
 69         logging.info(f"Dataset {data_file} loaded successfully")
 70         print(f"Dataset size: {len(dataset)}")
 71         tokenized_dataset = dataset.map(preprocess_function, batched=True)
 72         print(f"Tokenized dataset size: {len(tokenized_dataset)}")
 73 
 74         # 创建Trainer对象
 75         trainer = Trainer(
 76             model=model,
 77             args=training_args,
 78             train_dataset=tokenized_dataset,
 79         )
 80 
 81         logging.info(f"Before training on {data_file}")
 82 
 83         if last_checkpoint:
 84             logging.info(f"Loading from checkpoint {last_checkpoint}")
 85             trainer.train(resume_from_checkpoint=last_checkpoint)
 86             last_checkpoint = None  # 只在第一次加载检查点
 87         else:
 88             # 训练模型
 89             trainer.train()
 90 
 91         logging.info(f"After training on {data_file}")
 92 
 93         # 保存中间模型和检查点
 94         model.save_pretrained(f"/root/huggingface/GPT2/fine-tuned-model-{i}")
 95         tokenizer.save_pretrained(f"/root/huggingface/GPT2/fine-tuned-model-{i}")
 96         trainer.save_model(output_dir=f"/root/huggingface/GPT2/{checkpoint_prefix}-{i}")
 97         logging.info(f"Model saved successfully after training on {data_file}")
 98 
 99     # 最终保存模型
100     model.save_pretrained("/root/huggingface/GPT2/fine-tuned-model-final")
101     tokenizer.save_pretrained("/root/huggingface/GPT2/fine-tuned-model-final")
102     logging.info("Final model saved successfully")

  为了测试不同的参数的效果,我这里分别使用了r=8和r=16两个参数分别测试:

        r=8的trainable:

      

     cross entropy的loss:

     

     r=16的trainable:

     

     cross entropy的loss:     

      

    微调完成后推理:中文效果非常差,感觉完全没理解语义

  

    英语效果也好不到哪去:

  原因:

  • GPT2训练的中文语料少,对中文支持不足
  • 微调的语料太少,每批训练语料的epoch也只有1,loss最小也在3;

 

总结:

1、transformer架构,基础的block块核心由两部分构成:attention和FFN;FFN就是深度神经网络,十几年前就有了,没啥新的,所以attention是较大创新;

  • 在attention之前的embedding加上了位置编码,完美记录了token的位置,更好地表达语义;
  • attention本身通过q和k的向量内积提取token之间的相似度,找到每个token重要的context
  • q*k的乘积作为权重调整v的值,更好地表达token在当前context中的值;

     比如“apple”这个token,初始的embedding都是一样的,无法区分是水果的apple,还是电子产品的apple,只能根据不同的context确定apple这个词的语义;如果apple周围出现大量的pear、bananer、peach等,大概率指的是水果apple,理论上讲multihead中水果类head转换的v值会被q*k的内积增加,而电子产品类head转换的v值会被q*k的内积减小;所以,attention最核心的功能:根据context调整token的v值,让同一token在不同的context有不同的v值,更利于后续进一步的语义理解!目的感觉和互联网搜广推的“千人千面”很像啊!所以使用模型时处理token长度的max_length值不能太小(普通问答1024够了,专业的技术文章建议至少10240),否则提取的context信息有限,影响token v值的正确调整

2、transformer每个block都干同样的事:

  • attention:根据context调整token的v值
  • FFN:空间转换提取特征

      理论上讲:只要训练语料足够多,block数量足够多,经过层层变换,总能精准得到每个token的v值和语义信息,所以transformer的核心是“大力出奇迹”!

 

附:

1、知识库rag和微调的优劣对比:

 

标签:logging,GPT2,微调,checkpoint,token,LLM,model,data
From: https://www.cnblogs.com/theseventhson/p/18201535

相关文章

  • vllm服务推理参数
    stop:Listofstring。【生成文本时,碰到此token就会停下,但结果不会包含此token】stop_token_ids:Listofstring。【生成id时,碰到此id就会停止,会包含此id,比如tokenizer.eos_token_id[im_end]】最终判断是否停止,是两个的并集【同时考虑】参考:https://docs.vllm.ai/en/late......
  • LLM-文心一言:什么是电网WAMS?
    电网WAMS即广域测量系统(WideAreaMeasurementSystem),是基于同步向量技术构成的新一代电网动态监测和控制系统。WAMS的信息来源于PMU(相量测量单元)所采集的精确实时和同步信息,因此具有异地高精度同步向量测量、高速通信和快速反应等技术特点,非常适合大规模电网调度。它为电网实时......
  • LLM-文心一言:什么是SCADA系统
    SCADA系统,即数据采集与监视控制系统,是一种基于计算机的生产过程控制与调度自动化系统。它主要应用于电力、冶金、石油、化工、燃气、铁路等领域的数据采集与监视控制以及过程控制等诸多领域。在电力系统中,SCADA系统的应用最为广泛,技术发展也最为成熟。SCADA系统具有实时监控功能,......
  • bellmax-ford算的证明
    设\(dist[x]\)表示源点到\(x\)的最短路的距离(图中无负环),若对图中任意一条边\((x,y,z)\)满足\(dist[y]≤dist[x]+z\),那么\(dist\)就是最短路数组证:考虑任意一个点\(a\),假设找出了一条源点到\(a\)的最短路径{\(x_0,x_1,...,x_n,a\)},那么显然这条路径上\(x_0\)到任意一个点一定是最......
  • 一文彻底整明白,基于Ollama工具的LLM大语言模型Web可视化对话机器人部署指南
    在上一篇博文中,我们在本地部署了Llama38B参数大模型,并用Python写了一个控制台对话客户端,基本能愉快的与Llama大模型对话聊天了。但控制台总归太技术化,体验不是很友好,我们希望能有个类似ChatGPT那样的Web聊天对话界面,本博文就安排起来……上一篇Llama38B大模型部署......
  • GPT-4o 后 LLM 时代 RTC 需求讨论会丨社区伙伴活动分享
    我们将于5.19(周日)上午10点举办《GPT-4o后LLM时代RTC需求讨论会》讨论会。本期讨论会预计为线上闭门形式。 每期讨论会分别邀请了行业里面最资深的从业者、创业者朋友,每次讨论时间都限制在1.5-2小时,围绕主题,全程干货。 GPT-4o后最大的变动就是大模型的延迟时间大大缩短......
  • LLM实战:LLM微调加速神器-Unsloth + Qwen1.5
    1.背景上一篇介绍了基于训练加速框架Unsloth,微调训练Llama3的显卡资源占用及训练时间对比。近期Unsloth新增了Qwen1.5的模型适配,因此本qiang~马不停蹄地又进行了一次实验对比。对Unsloth的简介,感兴趣的客观可以查阅上一篇《LLM微调加速神器:Unsloth+LLama3》。2.实战本着......
  • LLM-通义千问:MyBatis-Plus 多租户插件
    MyBatis-Plus(简称MP)是一个对MyBatis的增强工具,在MyBatis的基础上只做增强不做改变,为简化开发、提高效率而生。多租户(Multi-Tenant)是一种软件架构技术,它允许多个组织(即“租户”)共享同一应用程序的实例,但每个组织的数据都是隔离且私有的。在多租户场景下,MyBatis-Plus提供了一种插件......
  • LLM-通义千问:掌握的Java知识过于零散,怎么系统地串一串
    要系统地整理和深化你的Java知识,可以遵循以下几个步骤和策略:1.回顾基础语言基础:从Java的基本语法开始,包括变量声明、数据类型、运算符、控制结构(if-else,for,while等)、数组和方法定义。面向对象编程:深入理解类与对象、继承、封装、多态这些核心概念,并实践如何设计良好的类......
  • 精通RAG架构:从0到1,基于LLM+RAG构建生产级企业知识库
    文章很长,且持续更新,建议收藏起来,慢慢读!疯狂创客圈总目录博客园版为您奉上珍贵的学习资源:免费赠送:《尼恩Java面试宝典》持续更新+史上最全+面试必备2000页+面试必备+大厂必备+涨薪必备免费赠送:《尼恩技术圣经+高并发系列PDF》,帮你实现技术自由,完成职业升级,薪......