to 2024 / 04 / 22
部署环境
OS: Windows10, WSL2 ( Ubuntu 20.04 )
CPU: Intel(R) Core(TM) i5-12490F
GPU: GeForce RTX 4070Ti
部署过程
部署主要参考$[2]$,其中也遇到了一定的问题,记录如下:
模型下载
模型需要使用Git LFS工具进行下载,由于之前在Windows环境下已经下载过模型文件,且文件较大,直接在系统内进行复制而没有重复下载(具体可以参考$[3]$)。WindowsPowerShell下载指令:
git clone https://huggingface.co/THUDM/chatglm-6b
需要将如下对应文件复制到WSL2自己设定的文件路径下:
环境配置
使用conda (4.5.11) 创建环境,pip (23.3.1)配置环境,可以尝试直接在git的项目$[1]$路径下运行:
pip install -r requirements.txt
最开始下载时存在部分模块(e.g. PyYAML)版本不一致问题,可能是conda最开始初始化时导致的,如果按照所需的环境逐个下载,可以尝试使用以下指令强行更新版本(但是无法删除,参考$[4]$):
pip3 install --ignore-installed PyYAML
在之后运行模型时,可能遇到 'Textbox' object has no attribute 'style'
报错,可能是gradio模块版本过高导致的,可以尝试单独指定gradio版本(参考$[5]$):
pip uninstall gradio
pip install gradio==3.50.0
DEMO & API 尝试
项目本身提供了web和cli两个demo,但个人在使用web demo加载时会出现问题,考虑到项目有自己单独的前端,所以该问题未解决,cli demo可以正常运行,需要修改cli_demo.py
中的部分内容:
LOCAL_PATH = "/home/lyc/workspace/ChatGLM-6B"
tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True).half().cuda()
需要注意的是LOCAL_PATH
需要是绝对路径。在显存不足时可以进行量化:
# 按需修改,目前只支持 4/8 bit 量化
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
命令行运行结果如下:
基于 P-Tuning 微调 ChatGLM-6B
安装依赖,且需要确保transformers模块版本为4.27.1,尝试运行如下代码:
pip install rouge_chinese nltk jieba datasets
export WANDB_DISABLED=true
在最开始git的项目中,your_path/ChatGLM-6B/ptuning
路径下提供了P-Tuning的demo,需要修改如下内容:
其中蓝框内的cli_demo.py是因为自带的web_demo我无法运行,简单修改了最开始目录下的内容来运行经过微调后的模型的,cli_demo.sh用于启动cli_demo.py,两者内容如下:
# cli_demo.py
import os, sys
import platform
import signal
import torch
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from arguments import ModelArguments, DataTrainingArguments
import readline
# LOCAL_PATH = "/home/lyc/workspace/ChatGLM-6B"
# tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True)
# model = AutoModel.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True).half().cuda()
# model = model.eval()
model = None
tokenizer = None
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
def build_prompt(history):
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM-6B:{response}"
return prompt
def signal_handler(signal, frame):
global stop_stream
stop_stream = True
def main():
global model, tokenizer
parser = HfArgumentParser((
ModelArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
else:
model_args = parser.parse_args_into_dataclasses()[0]
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True)
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection
if model_args.ptuning_checkpoint is not None:
print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
if model_args.pre_seq_len is not None:
# P-tuning v2
model = model.half().cuda()
model.transformer.prefix_encoder.float().cuda()
model = model.eval()
history = []
global stop_stream
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
os.system(clear_command)
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
continue
count = 0
for response, history in model.stream_chat(tokenizer, query, history=history):
if stop_stream:
stop_stream = False
break
else:
count += 1
if count % 8 == 0:
os.system(clear_command)
print(build_prompt(history), flush=True)
signal.signal(signal.SIGINT, signal_handler)
os.system(clear_command)
print(build_prompt(history), flush=True)
if __name__ == "__main__":
main()
在cli_demo.sh中,model_name_or_path需要改为你最开始下载模型的位置,ptuning_checkpoint需要与train.sh中的内容相对应,不同的训练模型会保存在不同地方。
PRE_SEQ_LEN=32
CUDA_VISIBLE_DEVICES=0 python3 cli_demo.py \
--model_name_or_path /home/lyc/workspace/ChatGLM-6B/chatglm-6b \
--ptuning_checkpoint output/adgen-chatglm-6b-pt-32-2e-2/checkpoint-500 \
--pre_seq_len $PRE_SEQ_LEN
橘框中为测试数据和训练数据,以json格式进行存储,形如:
[
{"content": "xxx1", "summary": "yyy1"},
{"content": "xxx2", "summary": "yyy2"},
...
{"content": "xxx3", "summary": "yyy3"}
]
红框为训练和测试的脚本,可以参考$[2]$按需修改对应参数 。
其他问题
部分模块或模型下载可能需要代理,个人使用clash代理,WSL2中需要配置git和conda的代理,git可以参考$[6]$,conda可以在用户目录下修改 .condarc
文件,增添内容:
proxy_servers:
http: http://nameserver:port
https: https://nameserver:port
ssl_verify: false
其中nameserver
可以在路径 /etc/resolv.conf
中查看,port请参考clash中的设置,缺省为7890。
后续(本科项目实训)
在测试中,使用 5 条数据训练 500 epoch,损失函数基本收敛,验证准确率较高,但距离目标任务的实际使用还有一定的距离,面对不同的输入格式的鲁棒性不足,需要设计输出函数格式并自动生成更多的训练测试数据。
本地部署算力较为吃紧,可能需要在服务器上进行微调。模型API需要进一步熟悉,以方便后续的项目开发。
参考资料
[1] ChatGLM-6B: An Open Bilingual Dialogue Language Model | 开源双语对话语言模型
https://github.com/THUDM/ChatGLM-6B
[2] chatglm的微调有没有保姆式的教程?? - 树先生的回答 - 知乎
https://www.zhihu.com/question/595670355/answer/3015099216
[3] 安装 Git Large File Storage
https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
[4] [已解決] Cannot uninstall ‘PyYAML’.
https://clay-atlas.com/blog/2022/04/08/cannot-uninstall-pyyaml-distutils-installed-project/#google_vignette
[5] chatglm2-b部署报错问题‘Textbox‘ object has no attribute ‘style‘
https://blog.csdn.net/m0_54393918/article/details/134355019
[6] WSL2 访问 Clash 网络代理
https://jike.dev/posts/wsl2-access-clash-network-proxy/