首页 > 其他分享 >基于SWIFT和Qwen1.5-14B-Chat进行大模型LoRA微调测试

基于SWIFT和Qwen1.5-14B-Chat进行大模型LoRA微调测试

时间:2024-03-08 09:14:11浏览次数:19  
标签:yldm0226 swift -- lora rate 14B Chat 0.0 LoRA

基于SWIFT和Qwen1.5-14B-Chat进行大模型LoRA微调测试

环境准备

基础环境

  • 操作系统:Ubuntu 18.04.5 LTS (GNU/Linux 3.10.0-1127.el7.x86_64 x86_64)
  • Anaconda3:Anaconda3-2023.03-1-Linux-x86_64
  • 根据服务器网络情况配置好conda源和pip源,此处使用的是超算山河源
  • 服务器硬件配置:CPU 96核;GPU 8×NVIDIA A100 40GB

环境安装

通过源代码安装SWIFT:

创建一个新的conda环境:

conda create --name swift python=3.8

激活刚刚创建的conda环境:

conda activate swift

下载SWIFT源码:

git clone https://github.com/modelscope/swift.git

切换到SWIFT路径:

cd /yldm0226/swift

安装SWIFT:

pip install -e .[llm]

非必要步骤:检查服务器cuda版本是否与当前安装的pytorch对应,详见:搭建一个大模型API服务

数据集准备

对于数据集,我们均采用json或jsonl的格式。

在做大模型SFT(Supervised Fine-Tuning)时,可以准备两种数据:

  1. 单轮问答
  2. 多轮对话

对于单轮问答数据,其格式可以为:

{"query": "11111", "response": "22222"}

对于多轮对话数据,其格式可以为:

{"query": "eeeee", "response": "fffff", "history": []}
{"query": "EEEEE", "response": "FFFFF", "history": [["AAAAA", "BBBBB"], ["CCCCC", "DDDDD"]]}

同时,也可以用以下两种格式的数据:

{"conversations": [{"from": "user", "value": "11111"}, {"from": "assistant", "value": "22222"}]}
{"conversations": [{"from": "user", "value": "aaaaa"}, {"from": "assistant", "value": "bbbbb"}, {"from": "user", "value": "ccccc"}, {"from": "assistant", "value": "ddddd"}]}
{"conversations": [{"from": "user", "value": "AAAAA"}, {"from": "assistant", "value": "BBBBB"}, {"from": "user", "value": "CCCCC"}, {"from": "assistant", "value": "DDDDD"}]}
{"messages": [{"role": "user", "content": "11111"}, {"role": "assistant", "content": "22222"}]}
{"messages": [{"role": "user", "content": "aaaaa"}, {"role": "assistant", "content": "bbbbb"}, {"role": "user", "content": "ccccc"}, {"role": "assistant", "content": "ddddd"}]}
{"messages": [{"role": "user", "content": "AAAAA"}, {"role": "assistant", "content": "BBBBB"}, {"role": "user", "content": "CCCCC"}, {"role": "assistant", "content": "DDDDD"}]}

在本文中,共使用了9个数据集,数据集的详细信息如下:

序号 数据集 简介 数据量
1 Chinese_medical_dialogue_six_department 中文医疗问答数据集,包括男科、内科、妇产科、肿瘤科、儿科、外科六个科室的问题。 792K
2 HuatuoGPT2_sft_instruct_GPT4 华佗GPT(HuatuoGPT)第二版训练数据集。 50K
3 ChatMed_Consult-v0.3 中文医疗在线问诊数据集ChatMed_Consult_Dataset的50w+在线问诊+ChatGPT回复。 500K
4 ChatMed_TCM-v0.2 以开源的中医药知识图谱为基础,采用以实体为中心的自指令方法(entity-centric self-instruct),调用ChatGPT得到11w+的围绕中医药的指令数据。 110K
5 QiZhen_sft_20k 包含20k训练数据(该数据集来自于启真医学知识库收集整理的真实医患知识问答数据以及在启真医学知识库的药品文本知识基础上,通过对半结构化数据设置特定的问题模板构造的指令数据)。 20K
6 Huatuo_Lite Huatuo-Lite 是在Huatuo26M数据集的基础上经过多次提纯和重写而精炼优化的数据集。它包含了18万个高质量的医疗问答对,并具有医院科室和相关疾病两个额外的数据维度。 180K
7 ZhongJing_CMtMedQA 仲景SFT训练集。 70K
8 DISC-Med-SFT_released 包含了超过47万个衍生于现有的医疗数据集重新构建得到的样本。采用了目标导向的策略,通过对于精心选择的几个数据源进行重构来得到SFT数据集。这些数据的作用在于帮助模型学习医疗领域知识,将行为模式与人类偏好对齐,并对齐真实世界在线医疗对话的分布情况。 514K
9 SZY_TCM_QA 私有数据集。 12K

以下是加载后的数据集信息:

[INFO:swift] train_dataset: Dataset({
    features: ['query', 'response', 'history'],
    num_rows: 2223540
})
[INFO:swift] val_dataset: Dataset({
    features: ['query', 'response', 'history'],
    num_rows: 22460
})

数据总量为2,246,000,从中抽取出约1%作为验证集,其余的作为训练集。

通过max_lengt=4096进行过滤后的数据集信息如下:

[INFO:swift] Dataset Token Length: 224.276768±159.001432, min=25.000000, max=4089.000000, size=2223411
[INFO:swift] Dataset Token Length: 224.254464±157.600093, min=28.000000, max=3086.000000, size=22459

编写微调脚本

SWIFT框架提供了部分大模型的微调脚本,可以在我们下载的源码中的swift/examples/pytorch/llm/scripts路径中找到这些脚本。如果这些脚本能够满足我们大部分的微调需求,我们可以选择直接对这些脚本进行修改。如果找不到我们需要的脚本,需要我们根据swift/docs/source/LLM中的命令行参数文档自行编写训练脚本。

以下是对Qwen1.5-14B-Chat进行LoRA微调的一个训练脚本:

nproc_per_node=8

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift sft \
    --model_type qwen1half-14b-chat \
    --model_id_or_path /yldm0226/models/Qwen1.5-14B-Chat \
    --model_revision master \
    --sft_type lora \
    --tuner_backend swift \
    --template_type qwen \
    --dtype AUTO \
    --output_dir /yldm0226/llm_sft_output \
    --ddp_backend nccl \
    --custom_train_dataset_path /yldm0226/data/1-Chinese_medical_dialogue_six_department.jsonl /yldm0226/data/2-HuatuoGPT2_sft_instruct_GPT4.jsonl /yldm0226/data/3-ChatMed_Consult-v0.3.jsonl /yldm0226/data/4-ChatMed_TCM-v0.2.jsonl /yldm0226/data/5-QiZhen_sft_20k.jsonl /yldm0226/data/6-Huatuo_Lite.jsonl /yldm0226/data/7-ZhongJing_CMtMedQA.jsonl /yldm0226/data/8-DISC-Med-SFT_released.jsonl /yldm0226/data/9-SZY_TCM_QA.jsonl \
    --train_dataset_sample -1 \
    --num_train_epochs 1 \
    --max_length 4096 \
    --check_dataset_strategy warning \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0.05 \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0.01 \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps $(expr 64 / $nproc_per_node) \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 3 \
    --logging_steps 10 \
    --use_flash_attn false \
    --deepspeed default-zero3 \
    --save_only_model true

该脚本中的一些参数在基于SWIFT和Qwen1.5-14B-Chat进行大模型全参微调测试中已经解释过了,此处简单介绍一下与LoRA相关的几个参数,如果你想了解LoRA具体的原理,请阅读该论文LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

lora_rank:微调中的秩大小。秩的值并不是越大越好,此处设置的8是LoRA原论文中测试的最优解,根据论文中的结果,1或者2这种很小的秩的表现也是很好的。

lora_alpha:LoRA 微调中的缩放系数。

lora_dropout_p:LoRA 微调中的 Dropout 系数。

lora_target_modules:指定lora模块, 默认为['DEFAULT']. 如果lora_target_modules传入'DEFAULT' or 'AUTO', 则根据model_type查找MODEL_MAPPING中的lora_target_modules(默认指定为qkv)。如果传入'ALL', 则将所有的Linear层(不含head)指定为lora模块。 如果传入'EMBEDDING', 则Embedding层指定为lora模块。 如果内存允许, 建议设置成'ALL'。 当然, 你也可以设置['ALL', 'EMBEDDING'], 将所有的Linear和embedding层指定为lora模块。该参数只有当sft_type指定为'lora'时才生效。

deepspeed:用于指定deepspeed的配置文件的路径或者直接传入json格式的配置信息, 默认为None, 即不开启deepspeed. deepspeed可以节约显存。 SWIFT书写了默认的ZeRO-2配置文件, ZeRO-3配置文件。你只需要指定'default-zero2', 就会使用默认zero2配置文件; 指定'default-zero3', 就会使用默认的zero3配置文件。

测试

以下是训练过程中的部分输出:

{'loss': 3.91967845, 'acc': 0.46053511, 'learning_rate': 0.0, 'epoch': 0.0, 'global_step': 1}                                                                                                                                                             
{'loss': 3.13938289, 'acc': 0.50242286, 'learning_rate': 3.313e-05, 'epoch': 0.0, 'global_step': 10}                                                                                                                                                      
{'loss': 2.02636986, 'acc': 0.56641636, 'learning_rate': 4.31e-05, 'epoch': 0.0, 'global_step': 20}                                                                                                                                                       
{'loss': 1.51573572, 'acc': 0.62124624, 'learning_rate': 4.894e-05, 'epoch': 0.0, 'global_step': 30}                                                                                                                                                      
{'loss': 1.37469482, 'acc': 0.65222416, 'learning_rate': 5.308e-05, 'epoch': 0.0, 'global_step': 40}                                                                                                                                                      
{'loss': 1.44527245, 'acc': 0.64013515, 'learning_rate': 5.629e-05, 'epoch': 0.0, 'global_step': 50}                                                                                                                                                      
{'loss': 1.36220665, 'acc': 0.65485716, 'learning_rate': 5.891e-05, 'epoch': 0.0, 'global_step': 60}                                                                                                                                                      
{'loss': 1.34706726, 'acc': 0.65729899, 'learning_rate': 6.113e-05, 'epoch': 0.0, 'global_step': 70}                                                                                                                                                      
{'loss': 1.3558219, 'acc': 0.65412712, 'learning_rate': 6.305e-05, 'epoch': 0.0, 'global_step': 80}                                                                                                                                                       
{'loss': 1.38924046, 'acc': 0.6498558, 'learning_rate': 6.475e-05, 'epoch': 0.0, 'global_step': 90}                                                                                                                                                       
{'loss': 1.31848869, 'acc': 0.66292844, 'learning_rate': 6.626e-05, 'epoch': 0.0, 'global_step': 100}                                                                                                                                                     
Train:   0%|▌                                                                                                                                                                                                     | 100/34740 [20:07<113:54:29, 11.84s/it]
Val:  22%|████████████████████████████████████████████                            | 615/2808 [04:56<17:36,  2.07it/s]                                                                                                                    

训练一个epoch大约需要114小时;进行一次验证大约需要22分钟。(这里的时间只是一个大概值,在训练时,不同数据的处理速度不同,花费的总时间会一直变化)。

相比于全参,LoRA的微调方式能够节约大量的显存,因此我们可以将nproc_per_node设置的大一些,以提高训练的速度。

标签:yldm0226,swift,--,lora,rate,14B,Chat,0.0,LoRA
From: https://www.cnblogs.com/yourenbo/p/18060256

相关文章

  • 千问72-chat私有化部署
    千问开源的版本挺多,版本有1和1.5,参数有1.8~72B,模态有语言、语音、视觉。72B就有Qwen-72b-chat(聊天)和Qwen-72b(基础/预训练)两个版本,以下为简单的Qwen-72b-chat的坑:1、下载模型(魔塔社区),权重文件140+G2、新建虚拟环境,基础要求:python>3.8、pytorch>1.12、cuda>11.4;依赖:"transforme......
  • ChatGLM3 源码解析(三)
    RotaryEmbedding#旋转位置嵌入,应用于每一层Q和KclassRotaryEmbedding(nn.Module):def__init__(self,dim,rope_ratio=1,original_impl=False,device=None,dtype=None):super().__init__()#除法项定义inv_freq=1.0/(10000**(to......
  • 文心一言 VS 讯飞星火 VS chatgpt (209)-- 算法导论15.4 6题
    六、设计一个O(nlgn)时间的算法,求一个n个数的序列的最长单调递增子序列。(提示:注意到,一个长度为i的候选子序列的尾元素至少不比一个长度为i-1候选子序列的尾元素小。因此,可以在输入序列中将候选子序列链接起来。)要写代码的时候,请用go语言。文心一言,代码正常运行:要设计一......
  • ChatGPT 新增朗读功能;微软 Win11 鼠标悬停自动打开 Copilot 丨 RTE 开发者日报 Vol.15
      开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑的个人观点......
  • 感谢ChatGPT,救了我狗的命!
    部门缺人,前端后端测试岗,base深圳/西安/北京:jinshuju.net/f/o38ijj前一段时间,国外一位小哥哥在推特上发布了一条消息,声称GPT-4拯救了自家狗狗的性命。这是怎么一回事呢?这个小哥哥养了一只两岁的边境牧羊犬,这只牧羊犬被诊断出患有蜱传疾病,这属于一种细菌性传染病。虽然小哥哥一......
  • 基于SWIFT和Qwen1.5-14B-Chat进行大模型全参微调测试
    基于SWIFT和Qwen1.5-14B-Chat进行大模型全参微调测试环境准备基础环境操作系统:Ubuntu18.04.5LTS(GNU/Linux3.10.0-1127.el7.x86_64x86_64)Anaconda3:Anaconda3-2023.03-1-Linux-x86_64根据服务器网络情况配置好conda源和pip源,此处使用的是超算山河源服务器硬件配置:CPU......
  • 文心一言 VS 讯飞星火 VS chatgpt (208)-- 算法导论15.4 5题
    五、设计一个O($n^2$)时间的算法,求一个n个数的序列的最长单调递增子序列。要写代码的时候,请用go语言。文心一言,抛panic:在Go语言中设计一个O(n^2)时间复杂度的算法来求一个n个数的序列的最长单调递增子序列(LongestIncreasingSubsequence,LIS)可以使用动态规划的方法......
  • ChatGPT用10秒画完一张UML流程图,而我用了。。。
    不用AI的程序员,失业潮真的快来临了。一张订单履约的流程图,我花了10分钟才完成,而ChatGPT绘图过程只用了10秒钟,基本可以达到同样的水平,通过ChatGPT可以显著提高画流程图的效率。订单履约流程是一系列精细协作的流程,从客户在销售平台下单开始,至商品交付用户手中结束。此过程跨越多......
  • ChatGPT4.0_API
    importopenaiimportjsonimportos#openai.api_key=get_api_key()openai.api_key="yourkey"#q="用python实现:提示手动输入3个不同的3位数区间,输入结束后计算这3个区间的交集,并输出结果区间"#q="WhatisthvbScript?"#q="翻译成日语:市场非常有潜力"#q="北京、......
  • ChatGLM3 源码解析(一)
    MLPclassMLP(torch.nn.Module):"""MLP把隐藏状态的尺寸从HidSize映射到4HidSize,执行非线性激活,然后再映射回HidSize"""def__init__(self,config:ChatGLMConfig,device=None):super(MLP,self).__init__()#控制是否添加偏......