首页 > 其他分享 >ChatGLM-6B-PT微调

ChatGLM-6B-PT微调

时间:2023-09-28 15:45:01浏览次数:45  
标签:6B PT -- ChatGLM2 steps ChatGLM output model

目录

开发环境

ChatGLM2-6B源码

git clone https://github.com/THUDM/ChatGLM2-6B.git

下载模型

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm2-6b-int4

安装依赖

cd ChatGLM2-6B
# 新建模型目录,将模型复制到model目录下
mkdir model
# 安装依赖
pip install -r requirements.txt -i https://mirror.sjtu.edu.cn/pypi/web/simple
# 运行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/

下载ADGEN数据集

微调前

image

cd ptuning
# 复制训练数据集到ptuning目录中
cp -r /mnt/AdvertiseGen .

# 微调训练
# 训练集目录 ptuning/AdvertiseGen/
# 模型目录 ChatGLM2-6B/model/
# 模型训练输出目录 ptuning/output/
# max_steps 最大训练步数
# save_steps 保存步骤数
# logging_steps 记录日志的频率
# quantization_bit 控制量化的精度
# pre_seq_len 预先设定的序列长度
# learning_rate 使用的学习率
# gradient_accumulation_steps 连续计算梯度的步数
torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/ChatGLM2-6B/model/chatglm2-6b-int4 \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --quantization_bit 4

修改训练步数

torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/ChatGLM2-6B/model/chatglm2-6b-int4 \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 100 \
    --logging_steps 10 \
    --save_steps 50 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --quantization_bit 4

微调后

# 微调后
import os
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel

model_path = "model/chatglm2-6b-int4"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# 微调后代码
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join("/home/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt--/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

# 使用 Markdown 格式打印模型输出
# from IPython.display import display, Markdown, clear_output

# def display_answer(model, query, history=[]):
#     for response, history in model.stream_chat(
#             tokenizer, query, history=history):
#         clear_output(wait=True)
#         display(Markdown(response))
#     return history

# display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")

prompt = "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞"
# 模型输出
current_length = 0
for response, history in model.stream_chat(tokenizer, prompt, history=[]):
    print(response[current_length:], end="", flush=True)
    current_length = len(response)
print("")

image

标签:6B,PT,--,ChatGLM2,steps,ChatGLM,output,model
From: https://www.cnblogs.com/wufengsheng/p/17735947.html

相关文章

  • 强推!!国产免费chatgpt,功能强大,速来体验!!!
    最近几个月,AIGC迅速崛起,周围的同学写代码、写ppt、写小红书都用上了各种AI工具。这两天讯飞的朋友给我推荐了他们的星火大模型。讯飞在基于自然语言处理领域积累了很多年的优势,拿过无数专利,因此朋友推荐给我星火大模型的时候我也是第一时间就注册申请使用了。使用后我是震惊了,太强......
  • JavaScript 中的类型、值和变量
     JavaScript类型可以分为两类:原始类型和对象类型。JavaScript的基本类型包括数字、文本字符串(称为字符串)和布尔真值(称为布尔值)。特殊的JavaScript值null和undefined是原始值,但它们不是数字、字符串或布尔值。每个值通常被认为是其自身特殊类型的唯一成员。ES6添加了一种新......
  • Chapter 1 自然地理
    atmospherehydrospherelithosphereoxygenoxidecarbondioxidehydrogencorecrustmantlelongtitudelatitudehorizonaltitudedisastermishapcatastrophiccalamityendangerjeopardisedestructiveEININOphenomenonpebblemagnetoremineralmarblequatz......
  • chapter 7 文件操作&chapter 8 使用系统调研进行文件操作
    chapter7文件操作&chapter8使用系统调研进行文件操作7.1文件操作文件操作由五个层次构成,从低到高,如下图所示。7.1.1硬件级别硬件级别的文件操作包括以下程序:fdisk:将硬盘、USB或SDC驱动器分成分区。mkfs:格式化磁盘分区以准备它们用于文件系统。fsck:检查和修复文件......
  • k8s持久化存储01 emptyDir hostPath
    本质上,K8svolume是一个目录,这点和Dockervolume差不多,当Volume被mount到Pod上,这个Pod中的所有容器都可以访问这个volume,常用的类型有这几种:emptyDirhostPathPersistentVolume(PV)&PersistentVolumeClaim(PVC)StorageClass01.emptyDiremptyDir是最基础的Volume类型,pod内的容器......
  • StaleElementReferenceException
    字面翻译是过时元素引用异常,通常是在获取元素之后,页面刷新/更新了所导致的。如,获取一个元素,然后页面刷新了,再使用text方法,这时就有这个异常解决办法,直接在获取时使用text方法。或者重新获取元素,然后再使用text方法......
  • JavaScript——小数精度丢失问题
    JavaScript小数进行数值运算时出现精度丢失问题1.原因:JavaScript的number类型在进行运算时都先将十进制转二进制,此时,小数点后面的数字转二进制时会出现无限循环的问题。为了避免这一个情况,要舍0进1,此时就会导致精度丢失问题。2.如何解决:(1)保留小数位数toFixed()constnumObj=......
  • 使用openssl_encrypt自己生成license.lic文件
     //生成加密文件publicfunctioncreateLicense(){//加密信息$licenseData=['user'=>'JohnDoe','expiry'=>'2022-12-31',];$licenseData=json_......
  • java.lang.IllegalStateException: javax.websocket.server.ServerContainer not avai
    spring项目能正常运行,但是单元测试报错错误原因注册WebSocket的Bean与springboot内带tomcat冲突解决办法1.注释该类里面的代码(不推荐)2.@springBootTest注解添加webEnvironment=SpringBootTest.WebEnvironment.RANDOM_PORT@SpringBootTest注解中,给出了webEnvironment参......
  • 戴尔OptiPlex 3020升级BIOS刷入NVME驱动
    前提:戴尔OptiPlex3020的主板是H81的,DELL官网的bios是不支持nvme启动的。我也是在外网找的,然后根据自己的情况刷的。目前电脑刷了后是可以直接选择nvme的ssd启动的。外网链接如下:https://www.tachytelic.net/2021/12/dell-optiplex-7020-nvme-ssd/简单的说下步骤:1、先从DELL官......