GLM-4v-9B 源码解析(二)
.\chatglm4-finetune\basic_demo\trans_web_demo.py
# 该脚本使用 Gradio 创建 GLM-4-9B 模型的交互式网络演示
"""
This script creates an interactive web demo for the GLM-4-9B model using Gradio,
a Python library for building quick and easy UI components for machine learning models.
It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
allowing users to interact with the model through a chat-like interface.
"""
# 导入操作系统模块
import os
# 从 pathlib 导入 Path 类以处理路径
from pathlib import Path
# 导入 Thread 类以支持多线程
from threading import Thread
# 导入 Union 类型以支持类型注解
from typing import Union
# 导入 Gradio 库以构建用户界面
import gradio as gr
# 导入 PyTorch 库以支持深度学习模型
import torch
# 从 peft 导入相关模型类
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
# 从 transformers 导入所需的类
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
# 定义模型类型的别名
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# 定义分词器类型的别名
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
# 从环境变量获取模型路径,若不存在则使用默认路径
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
# 从环境变量获取分词器路径,若不存在则使用模型路径
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
# 定义解析路径的辅助函数
def _resolve_path(path: Union[str, Path]) -> Path:
# 扩展用户路径并解析为绝对路径
return Path(path).expanduser().resolve()
# 定义加载模型和分词器的函数
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
# 解析模型目录路径
model_dir = _resolve_path(model_dir)
# 检查是否存在适配器配置文件
if (model_dir / 'adapter_config.json').exists():
# 从预训练模型中加载适配器模型
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
# 获取基础模型名称
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
# 从预训练模型中加载普通模型
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
# 将模型目录设为分词器目录
tokenizer_dir = model_dir
# 从预训练分词器中加载分词器
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
)
# 返回模型和分词器的元组
return model, tokenizer
# 加载模型和分词器
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
# 定义基于特定标记停止的类
class StopOnTokens(StoppingCriteria):
# 定义调用方法以检查是否应停止
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
# 获取结束标记 ID
stop_ids = model.config.eos_token_id
# 检查输入 ID 的最后一个是否为停止 ID
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
# 如果没有停止条件,返回 False
return False
# 定义预测函数
def predict(history, prompt, max_length, top_p, temperature):
# 创建停止条件实例
stop = StopOnTokens()
# 初始化消息列表
messages = []
# 如果提示存在,将其添加到消息中
if prompt:
messages.append({"role": "system", "content": prompt})
# 遍历历史消息
for idx, (user_msg, model_msg) in enumerate(history):
# 如果提示存在且是第一条消息,则跳过
if prompt and idx == 0:
continue
# 如果是最后一条消息且模型消息为空
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
# 如果用户消息存在,则添加到消息中
if user_msg:
messages.append({"role": "user", "content": user_msg})
# 如果模型消息存在,则添加到消息中
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
# 使用 tokenizer 将消息应用于聊天模板,生成模型输入
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True, # 添加生成提示
tokenize=True, # 启用标记化
return_tensors="pt").to(next(model.parameters()).device) # 转移到模型设备
# 创建一个文本迭代器流,用于实时生成文本
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
# 定义生成所需的参数
generate_kwargs = {
"input_ids": model_inputs, # 模型输入的标记ID
"streamer": streamer, # 指定文本流
"max_new_tokens": max_length, # 生成的最大新标记数量
"do_sample": True, # 启用采样以生成多样化输出
"top_p": top_p, # 限制采样的前p概率
"temperature": temperature, # 控制生成文本的随机性
"stopping_criteria": StoppingCriteriaList([stop]), # 设置停止标准
"repetition_penalty": 1.2, # 防止重复生成
"eos_token_id": model.config.eos_token_id, # 结束标记的ID
}
# 创建一个线程用于执行生成
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start() # 启动线程
# 迭代从流中接收的新标记
for new_token in streamer:
if new_token: # 如果新标记存在
history[-1][1] += new_token # 将新标记添加到历史记录的最后一个条目
yield history # 生成历史记录的当前状态
# 使用 Gradio 创建一个聊天应用的块
with gr.Blocks() as demo:
# 添加 HTML 标题,居中显示
gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
# 初始化聊天机器人对象
chatbot = gr.Chatbot()
# 创建一个行布局
with gr.Row():
# 第一列,宽度比例为 3
with gr.Column(scale=3):
# 嵌套的列,宽度比例为 12
with gr.Column(scale=12):
# 用户输入框,隐藏标签,提示为 "Input..."
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
# 嵌套的列,最小宽度为 32,宽度比例为 1
with gr.Column(min_width=32, scale=1):
# 提交按钮,标签为 "Submit"
submitBtn = gr.Button("Submit")
# 第二列,宽度比例为 1
with gr.Column(scale=1):
# 提示输入框,隐藏标签,提示为 "Prompt"
prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
# 设置提示按钮,标签为 "Set Prompt"
pBtn = gr.Button("Set Prompt")
# 第三列,宽度比例为 1
with gr.Column(scale=1):
# 清除历史记录按钮,标签为 "Clear History"
emptyBtn = gr.Button("Clear History")
# 最大长度滑动条,范围为 0 到 32768,初始值为 8192
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
# Top P 滑动条,范围为 0 到 1,初始值为 0.8
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
# 温度滑动条,范围为 0.01 到 1,初始值为 0.6
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
# 用户输入处理函数,返回空字符串和更新的历史记录
def user(query, history):
return "", history + [[query, ""]]
# 设置提示处理函数,返回包含提示文本和成功消息的列表
def set_prompt(prompt_text):
return [[prompt_text, "成功设置prompt"]]
# 点击设置提示按钮时,调用 set_prompt 函数,输入为 prompt_input,输出为 chatbot
pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
# 点击提交按钮时,调用 user 函数,输入为 user_input 和 chatbot,输出为更新后的 user_input 和 chatbot,且不排队
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
# 在用户提交后调用 predict 函数,输入为 chatbot、prompt_input、max_length、top_p 和 temperature,输出为 chatbot
predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
)
# 点击清除历史记录按钮时,调用匿名函数返回 None,更新 chatbot 和 prompt_input,且不排队
emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
# 启用队列处理
demo.queue()
# 启动 Gradio 应用,设置服务器名称和端口,自动在浏览器中打开并共享
demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)
.\chatglm4-finetune\basic_demo\trans_web_vision_demo.py
"""
# 该脚本创建一个 Gradio 演示,使用 glm-4v-9b 模型作为 Transformers 后端,允许用户通过 Gradio Web UI 与模型互动。
# 使用方法:
# - 运行脚本以启动 Gradio 服务器。
# - 通过 Web UI 与模型互动。
# 需求:
# - Gradio 包
# - 输入 `pip install gradio` 来安装 Gradio。
"""
# 导入必要的库
import os # 用于处理操作系统功能,如环境变量
import torch # PyTorch 库,用于深度学习
import gradio as gr # Gradio 库,用于创建 Web UI
from threading import Thread # 用于多线程处理
from transformers import ( # 从 transformers 库导入模型和 tokenizer
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer, AutoModel, BitsAndBytesConfig
)
from PIL import Image # 用于处理图像
import requests # 用于发送 HTTP 请求
from io import BytesIO # 用于在内存中处理字节流
# 从环境变量获取模型路径,默认值为 'THUDM/glm-4v-9b'
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
# 加载预训练的 tokenizer
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH, # 使用指定的模型路径
trust_remote_code=True, # 信任远程代码
encode_special_tokens=True # 编码特殊标记
)
# 加载预训练的模型
model = AutoModel.from_pretrained(
MODEL_PATH, # 使用指定的模型路径
trust_remote_code=True, # 信任远程代码
device_map="auto", # 自动选择设备
torch_dtype=torch.bfloat16 # 指定模型使用的浮点精度
).eval() # 将模型设置为评估模式
# 定义停止条件类
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = model.config.eos_token_id # 获取结束标记 ID
for stop_id in stop_ids: # 遍历所有结束标记 ID
if input_ids[0][-1] == stop_id: # 如果当前输入的最后一个 ID 是结束标记
return True # 返回 True 以停止生成
return False # 否则返回 False
# 定义获取图像的函数
def get_image(image_path=None, image_url=None):
if image_path: # 如果提供了本地图像路径
return Image.open(image_path).convert("RGB") # 打开图像并转换为 RGB 格式
elif image_url: # 如果提供了图像 URL
response = requests.get(image_url) # 发送 GET 请求获取图像
return Image.open(BytesIO(response.content)).convert("RGB") # 打开图像并转换为 RGB 格式
return None # 如果没有提供图像,则返回 None
# 定义聊天机器人的主函数
def chatbot(image_path=None, image_url=None, assistant_prompt=""):
image = get_image(image_path, image_url) # 获取图像
# 准备消息列表,包括助手的提示和用户的输入
messages = [
{"role": "assistant", "content": assistant_prompt}, # 助手消息
{"role": "user", "content": "", "image": image} # 用户消息
]
# 使用 tokenizer 将消息转换为模型输入格式
model_inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True, # 添加生成提示
tokenize=True, # 启用标记化
return_tensors="pt", # 返回 PyTorch 张量
return_dict=True # 返回字典格式
).to(next(model.parameters()).device) # 移动到模型所在设备
# 创建文本流迭代器以进行实时生成
streamer = TextIteratorStreamer(
tokenizer=tokenizer, # 使用的 tokenizer
timeout=60, # 超时时间
skip_prompt=True, # 跳过提示
skip_special_tokens=True # 跳过特殊标记
)
# 设置生成的参数
generate_kwargs = {
**model_inputs, # 包含模型输入
"streamer": streamer, # 使用的文本流迭代器
"max_new_tokens": 1024, # 生成的最大新标记数量
"do_sample": True, # 启用采样
"top_p": 0.8, # 使用 Top-p 采样
"temperature": 0.6, # 控制输出的随机性
"stopping_criteria": StoppingCriteriaList([StopOnTokens()]), # 设置停止条件
"repetition_penalty": 1.2, # 重复惩罚
"eos_token_id": [151329, 151336, 151338], # 结束标记 ID 列表
}
# 启动生成线程
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start() # 启动线程
response = "" # 初始化响应字符串
for new_token in streamer: # 遍历生成的每个新标记
if new_token: # 如果有新标记
response += new_token # 将新标记添加到响应中
return image, response.strip() # 返回图像和去除空白的响应
# 使用 Gradio 创建演示界面
with gr.Blocks() as demo:
demo.title = "GLM-4V-9B Image Recognition Demo" # 设置演示的标题
demo.description = """ # 设置演示的描述
This demo uses the GLM-4V-9B model to got image infomation.
"""
# 创建一个水平排列的容器
with gr.Row():
# 创建一个垂直排列的容器
with gr.Column():
# 创建一个文件输入框,供用户上传高优先级的图像
image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath")
# 创建一个文本框,供用户输入低优先级的图像 URL
image_url_input = gr.Textbox(label="Image URL (Low-Priority)")
# 创建一个文本框,供用户输入助手提示,可以修改
assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="这是什么?")
# 创建一个提交按钮
submit_button = gr.Button("Submit")
# 另一个垂直排列的容器
with gr.Column():
# 创建一个文本框,用于显示 GLM-4V-9B 模型的响应
chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response")
# 创建一个图像组件,用于显示图像预览
image_output = gr.Image(label="Image Preview")
# 为提交按钮设置点击事件,调用 chatbot 函数
submit_button.click(chatbot,
# 定义输入为三个组件:上传的图像路径、图像 URL 和助手提示
inputs=[image_path_input, image_url_input, assistant_prompt_input],
# 定义输出为两个组件:图像输出和聊天机器人的输出
outputs=[image_output, chatbot_output])
# 启动 demo 应用,指定服务器地址和端口
demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False)
.\chatglm4-finetune\basic_demo\vllm_cli_demo.py
"""
这个脚本创建了一个命令行界面(CLI)示例,使用 vllm 后端和 glm-4-9b 模型,
允许用户通过命令行接口与模型进行交互。
用法:
- 运行脚本以启动 CLI 演示。
- 通过输入问题与模型进行交互,并接收回答。
注意:该脚本包含一个修改,以处理 Markdown 到纯文本的转换,
确保 CLI 界面正确显示格式化文本。
"""
# 导入时间模块
import time
# 导入异步编程模块
import asyncio
# 从 transformers 库导入 AutoTokenizer
from transformers import AutoTokenizer
# 从 vllm 库导入相关的采样参数和引擎
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
# 导入类型注解
from typing import List, Dict
# 从 vllm.lora.request 导入 LoRARequest
from vllm.lora.request import LoRARequest
# 定义模型路径
MODEL_PATH = 'THUDM/glm-4-9b-chat'
# 初始化 LoRA 路径为空
LORA_PATH = ''
# 定义加载模型和分词器的函数
def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
# 创建异步引擎参数的实例
engine_args = AsyncEngineArgs(
model=model_dir, # 设置模型路径
tokenizer=model_dir, # 设置分词器路径
enable_lora=enable_lora, # 是否启用 LoRA
tensor_parallel_size=1, # 设置张量并行大小
dtype="bfloat16", # 设置数据类型
trust_remote_code=True, # 允许远程代码
gpu_memory_utilization=0.9, # GPU 内存利用率
enforce_eager=True, # 强制使用急切执行
worker_use_ray=True, # 使用 Ray 来处理工作
disable_log_requests=True # 禁用日志请求
# 如果遇见 OOM 现象,建议开启下述参数
# enable_chunked_prefill=True,
# max_num_batched_tokens=8192
)
# 从预训练模型加载分词器
tokenizer = AutoTokenizer.from_pretrained(
model_dir, # 模型路径
trust_remote_code=True, # 允许远程代码
encode_special_tokens=True # 编码特殊符号
)
# 从引擎参数创建异步 LLM 引擎
engine = AsyncLLMEngine.from_engine_args(engine_args)
# 返回引擎和分词器
return engine, tokenizer
# 初始化 LoRA 启用标志为 False
enable_lora = False
# 如果有 LoRA 路径,则启用 LoRA
if LORA_PATH:
enable_lora = True
# 加载模型和分词器
engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)
# 定义异步生成函数
async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
# 应用聊天模板处理输入消息
inputs = tokenizer.apply_chat_template(
messages, # 输入的消息
add_generation_prompt=True, # 添加生成提示
tokenize=False # 不进行标记化
)
# 定义采样参数的字典
params_dict = {
"n": 1, # 生成的响应数量
"best_of": 1, # 从中选择最佳响应
"presence_penalty": 1.0, # 存在惩罚
"frequency_penalty": 0.0, # 频率惩罚
"temperature": temperature, # 温度参数
"top_p": top_p, # 样本的累积概率阈值
"top_k": -1, # 前 K 个采样
"use_beam_search": False, # 不使用束搜索
"length_penalty": 1, # 长度惩罚
"early_stopping": False, # 不提前停止
"ignore_eos": False, # 不忽略结束符
"max_tokens": max_dec_len, # 最大生成长度
"logprobs": None, # 日志概率
"prompt_logprobs": None, # 提示日志概率
"skip_special_tokens": True, # 跳过特殊符号
}
# 创建采样参数实例
sampling_params = SamplingParams(**params_dict)
# 如果启用了 LoRA,则使用 LoRA 请求生成输出
if enable_lora:
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)):
# 生成输出文本
yield output.outputs[0].text
# 否则,直接生成输出
else:
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
# 生成输出文本
yield output.outputs[0].text
# 定义聊天的异步函数
async def chat():
# 初始化聊天历史记录
history = []
# 设置最大长度
max_length = 8192
# 设置 top_p 参数
top_p = 0.8
# 设置温度参数
temperature = 0.6
# 打印欢迎消息
print("欢迎来到 GLM-4-9B CLI 聊天。请在下面输入您的消息。")
# 无限循环,直到用户选择退出
while True:
# 提示用户输入
user_input = input("\nYou: ")
# 检查用户输入是否为退出命令
if user_input.lower() in ["exit", "quit"]:
break
# 将用户输入添加到历史记录中,初始助手回复为空
history.append([user_input, ""])
# 初始化消息列表
messages = []
# 遍历历史记录,构建消息列表
for idx, (user_msg, model_msg) in enumerate(history):
# 如果是最后一条用户消息且没有助手回复,则只添加用户消息
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
# 如果有用户消息,则添加到消息列表
if user_msg:
messages.append({"role": "user", "content": user_msg})
# 如果有助手回复,则添加到消息列表
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
# 打印助手的响应前缀
print("\nGLM-4: ", end="")
# 当前输出长度初始化为0
current_length = 0
# 初始化输出字符串
output = ""
# 异步生成助手的响应
async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length):
# 打印输出中从当前长度开始的新内容
print(output[current_length:], end="", flush=True)
# 更新当前输出长度
current_length = len(output)
# 更新历史记录中最后一条消息的助手回复
history[-1][1] = output
# 当脚本直接运行时,以下代码将被执行
if __name__ == "__main__":
# 使用 asyncio 运行 chat() 协程
asyncio.run(chat())
.\chatglm4-finetune\basic_demo\vllm_cli_vision_demo.py
# 该脚本创建一个 CLI 演示,使用 vllm 后端支持 glm-4v-9b 模型,
# 允许用户通过命令行界面与模型互动。
# 使用说明:
# - 运行脚本以启动 CLI 演示。
# - 输入问题与模型互动,获取响应。
# 注意:该脚本包含修改,以处理 markdown 到纯文本的转换,
# 确保 CLI 接口正确显示格式化文本。
"""
import time # 导入时间模块,用于时间相关功能
import asyncio # 导入异步模块,以支持异步编程
from PIL import Image # 从 PIL 库导入 Image 类,用于图像处理
from typing import List, Dict # 从 typing 导入 List 和 Dict 类型,用于类型注释
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine # 从 vllm 导入相关类和参数
MODEL_PATH = 'THUDM/glm-4v-9b' # 定义模型路径常量
# 定义函数以加载模型和分词器
def load_model_and_tokenizer(model_dir: str):
# 设置异步引擎参数
engine_args = AsyncEngineArgs(
model=model_dir, # 指定模型目录
tensor_parallel_size=1, # 设置张量并行大小
dtype="bfloat16", # 指定数据类型为 bfloat16
trust_remote_code=True, # 信任远程代码执行
gpu_memory_utilization=0.9, # 设置 GPU 内存利用率
enforce_eager=True, # 强制使用急切执行
worker_use_ray=True, # 启用 Ray 进行工作者管理
disable_log_requests=True, # 禁用日志请求
# 如果遇见 OOM 现象,建议开启下述参数
# enable_chunked_prefill=True, # 启用分块预填充
# max_num_batched_tokens=8192 # 设置最大批处理令牌数
)
# 从引擎参数创建异步 LLM 引擎
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine # 返回创建的引擎
# 调用函数以加载模型和分词器
engine = load_model_and_tokenizer(MODEL_PATH)
# 定义异步生成函数
async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
inputs = messages[-1] # 获取消息列表中的最后一条消息作为输入
params_dict = {
"n": 1, # 设置生成数量为 1
"best_of": 1, # 设置最佳选择数量为 1
"presence_penalty": 1.0, # 设置出现惩罚为 1.0
"frequency_penalty": 0.0, # 设置频率惩罚为 0.0
"temperature": temperature, # 设置生成温度
"top_p": top_p, # 设置 top_p 参数
"top_k": -1, # 设置 top_k 参数为 -1,表示不使用
"use_beam_search": False, # 不使用束搜索
"length_penalty": 1, # 设置长度惩罚为 1
"early_stopping": False, # 不启用早停
"ignore_eos": False, # 不忽略结束标记
"max_tokens": max_dec_len, # 设置最大令牌数
"logprobs": None, # 日志概率设置为 None
"prompt_logprobs": None, # 提示日志概率设置为 None
"skip_special_tokens": True, # 跳过特殊令牌
"stop_token_ids" :[151329, 151336, 151338] # 设置停止令牌 ID
}
# 使用参数字典创建采样参数
sampling_params = SamplingParams(**params_dict)
# 异步生成输出
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
yield output.outputs[0].text # 生成输出文本
# 定义异步聊天函数
async def chat():
history = [] # 初始化聊天历史
max_length = 8192 # 设置最大长度为 8192
top_p = 0.8 # 设置 top_p 参数为 0.8
temperature = 0.6 # 设置温度参数为 0.6
image = None # 初始化图像变量
print("Welcome to the GLM-4v-9B CLI chat. Type your messages below.") # 输出欢迎信息
image_path = input("Image Path:") # 提示用户输入图像路径
try:
# 尝试打开并转换图像为 RGB 格式
image = Image.open(image_path).convert("RGB")
except:
# 捕获异常并提示用户路径无效,继续文本对话
print("Invalid image path. Continuing with text conversation.")
# 无限循环,直到用户选择退出
while True:
# 获取用户输入
user_input = input("\nYou: ")
# 检查用户输入是否为退出命令
if user_input.lower() in ["exit", "quit"]:
break
# 将用户输入添加到历史记录,初始化模型响应为空
history.append([user_input, ""])
# 初始化消息列表
messages = []
# 遍历历史记录中的消息
for idx, (user_msg, model_msg) in enumerate(history):
# 如果是最新的用户消息且没有模型响应,构造包含图像的消息
if idx == len(history) - 1 and not model_msg:
messages.append({
"prompt": user_msg,
"multi_modal_data": {
"image": image
},})
break
# 如果存在用户消息,添加到消息列表
if user_msg:
messages.append({"role": "user", "prompt": user_msg})
# 如果存在模型消息,添加到消息列表
if model_msg:
messages.append({"role": "assistant", "prompt": model_msg})
# 打印模型的响应,准备输出
print("\nGLM-4v: ", end="")
# 当前输出长度初始化为0
current_length = 0
# 初始化输出字符串
output = ""
# 异步生成模型输出
async for output in vllm_gen(messages, top_p, temperature, max_length):
# 输出当前生成的内容,保持在同一行
print(output[current_length:], end="", flush=True)
# 更新当前输出长度
current_length = len(output)
# 更新历史记录中最新消息的模型响应
history[-1][1] = output
# 当脚本被直接运行时,执行以下代码
if __name__ == "__main__":
# 启动异步事件循环并运行 chat 函数
asyncio.run(chat())
.\chatglm4-finetune\composite_demo\browser\src\browser.ts
import { JSDOM } from 'jsdom'; // 从 jsdom 库导入 JSDOM 类,用于创建 DOM 对象
import TurndownService from 'turndown'; // 从 turndown 库导入 TurndownService 类,用于将 HTML 转换为 Markdown
import config from './config'; // 导入配置文件中的配置
import { Message, ToolObservation } from './types'; // 导入 Message 和 ToolObservation 类型定义
import { logger, withTimeout } from './utils'; // 从 utils 模块导入 logger 和 withTimeout 工具函数
// 表示显示中的引用
interface Quote {
text: string; // 引用的文本内容
metadata: Metadata[]; // 引用的元数据数组
}
interface ActionResult {
contentType: string; // 内容类型
metadataList?: TetherQuoteMetadata[]; // 可选的元数据列表
metadata?: any; // 可选的元数据
roleMetadata: string; // 角色元数据
message: string; // 消息内容
}
// 表示要在最终答案中标记的元数据
interface Metadata {
type: string; // 元数据类型
title: string; // 元数据标题
url: string; // 元数据链接
lines: string[]; // 相关行的数组
}
interface TetherQuoteExtra {
cited_message_idx: number; // 引用的消息索引
evidence_text: string; // 证据文本
}
interface TetherQuoteMetadata {
type: string; // 元数据类型
title: string; // 元数据标题
url: string; // 元数据链接
text: string; // 元数据文本
pub_date?: string; // 可选的发布日期
extra?: TetherQuoteExtra; // 可选的附加信息
}
interface Citation {
citation_format_type: string; // 引用格式类型
start_ix: number; // 开始索引
end_ix: number; // 结束索引
metadata?: TetherQuoteMetadata; // 可选的元数据
invalid_reason?: string; // 可选的无效原因
}
interface PageState {
aCounter: number; // 链接计数器
imgCounter: number; // 图片计数器
url: URL; // 当前页面的 URL 对象
url_string: string; // 当前页面的 URL 字符串
hostname: string; // 当前页面的主机名
links: string[]; // 当前页面的链接数组
links_meta: TetherQuoteMetadata[]; // 链接的元数据数组
lines: string[]; // 当前页面的文本行
line_source: Record<string, Metadata>; // 行的元数据,键为字符串表示的区间
title?: string; // 可选的页面标题
}
interface BrowserState {
pageStack: PageState[]; // 页面状态栈
quoteCounter: number; // 引用计数器
quotes: Record<string, Quote>; // 引用记录,键为引用 ID
}
// 移除密集链接的函数,接受一个文档和比率阈值
function removeDenseLinks(document: Document, ratioThreshold: number = 0.5) {
// 移除导航元素
const navs = document.querySelectorAll('nav'); // 查询所有导航元素
navs.forEach(nav => { // 遍历每个导航元素
if (nav.parentNode) { // 如果有父节点
nav.parentNode.removeChild(nav); // 从父节点中移除导航元素
}
});
// 查询列表、div、span、表格和段落元素
const elements = document.querySelectorAll('ul, ol, div, span, nav, table, p'); // 查询相关元素
elements.forEach(element => { // 遍历每个元素
if (element === null) return; // 如果元素为 null,直接返回
const children = Array.from(element.childNodes); // 将子节点转换为数组
const links = element.querySelectorAll('a'); // 查询所有链接元素
if (children.length <= 1) return; // 如果子节点数量小于等于1,直接返回
const allText = element.textContent ? element.textContent.trim().replace(/\s+/g, '') : ''; // 获取元素文本内容并去除多余空格
const linksText = Array.from(links) // 将链接文本合并成一个字符串
.map(link => (link.textContent ? link.textContent.trim() : '')) // 处理每个链接的文本
.join('') // 合并为单个字符串
.replace(/\s+/g, ''); // 去除多余空格
if (allText.length === 0 || linksText.length === 0) return; // 如果没有文本内容或链接文本,直接返回
let ratio = linksText.length / allText.length; // 计算链接文本占总文本的比率
if (ratio > ratioThreshold && element.parentNode) { // 如果比率超过阈值且有父节点
element.parentNode.removeChild(element); // 从父节点中移除该元素
}
});
}
abstract class BaseBrowser {
public static toolName = 'browser' as const; // 定义工具名称为 'browser'
public description = 'BaseBrowser'; // 描述为 'BaseBrowser'
private turndownService = new TurndownService({ // 初始化 TurndownService 实例
headingStyle: 'atx', // 设置标题样式为 atx
});
private state: BrowserState; // 声明浏览器状态
private transform(dom: JSDOM): string { // 转换函数,接收 JSDOM 对象,返回字符串
let state = this.lastPageState(); // 获取最后一个页面状态
state.aCounter = 0; // 重置链接计数器
state.imgCounter = 0; // 重置图片计数器
state.links = []; // 清空链接数组
return this.turndownService.turndown(dom.window.document); // 将 DOM 文档转换为 Markdown
}
private formatPage(state: PageState): string { // 格式化页面函数,接收页面状态,返回字符串
// 将状态中的行合并成一个字符串,以换行符分隔
let formatted_lines = state.lines.join('\n');
// 如果标题存在,则格式化标题并添加换行符,否则为空字符串
let formatted_title = state.title ? `TITLE: ${state.title}\n\n` : '';
// 定义可见范围的格式化字符串
let formatted_range = `\nVisible: 0% - 100%`;
// 将标题、行和可见范围合并成一个完整的消息字符串
let formatted_message = formatted_title + formatted_lines + formatted_range;
// 返回格式化后的消息字符串
return formatted_message;
}
// 创建新的页面状态并返回
private newPageState(): PageState {
return {
// 初始化计数器 aCounter 为 0
aCounter: 0,
// 初始化计数器 imgCounter 为 0
imgCounter: 0,
// 创建新的 URL 对象,指向空白页面
url: new URL('about:blank'),
// 初始化 URL 字符串为 'about:blank'
url_string: 'about:blank',
// 初始化主机名为空字符串
hostname: '',
// 初始化标题为空字符串
title: '',
// 初始化链接数组为空
links: [],
// 初始化链接元数据数组为空
links_meta: [],
// 初始化行数组为空
lines: [],
// 初始化行源对象为空
line_source: {},
};
}
// 推送新的页面状态到状态栈并返回该状态
private pushPageState(): PageState {
// 调用 newPageState 创建一个新的页面状态
let state = this.newPageState();
// 将新状态推入状态栈
this.state.pageStack.push(state);
// 返回新创建的页面状态
return state;
}
// 获取状态栈中的最后一个页面状态
private lastPageState(): PageState {
// 如果状态栈为空,抛出错误
if (this.state.pageStack.length === 0) {
throw new Error('No page state');
}
// 返回状态栈中的最后一个页面状态
return this.state.pageStack[this.state.pageStack.length - 1];
}
// 格式化错误 URL,限制其长度
private formatErrorUrl(url: string): string {
// 定义截断限制为 80 个字符
let TRUNCATION_LIMIT = 80;
// 如果 URL 长度小于等于限制,直接返回该 URL
if (url.length <= TRUNCATION_LIMIT) {
return url;
}
// 如果 URL 超过限制,截断并返回格式化的字符串
return url.slice(0, TRUNCATION_LIMIT) + `... (URL truncated at ${TRUNCATION_LIMIT} chars)`;
}
// 定义一个包含异步搜索功能的对象
protected functions = {
// 异步搜索函数,接收查询字符串和最近几天的参数,默认值为 -1
search: async (query: string, recency_days: number = -1) => {
// 记录调试信息,显示正在搜索的内容
logger.debug(`Searching for: ${query}`);
// 创建 URL 查询参数对象,包含搜索查询
const search = new URLSearchParams({ q: query });
// 如果 recency_days 大于 0,添加相应的查询参数
recency_days > 0 && search.append('recency_days', recency_days.toString());
// 如果自定义配置 ID 存在,添加相应的查询参数
if (config.CUSTOM_CONFIG_ID) {
search.append('customconfig', config.CUSTOM_CONFIG_ID.toString());
},
# 定义一个打开 URL 的函数,接受一个字符串类型的 URL
open_url: (url: string) => {
# 记录调试信息,输出当前打开的 URL
logger.debug(`Opening ${url}`);
# 设置超时限制,并发起网络请求,获取响应文本
return withTimeout(
config.BROWSER_TIMEOUT,
fetch(url).then(res => res.text()),
)
# 处理请求响应,提取返回值和耗时
.then(async ({ value: res, time }) => {
try {
# 获取当前页面状态,并记录 URL 信息
const state = this.pushPageState();
state.url = new URL(url); # 创建 URL 对象
state.url_string = url; # 存储原始 URL 字符串
state.hostname = state.url.hostname; # 提取主机名
const html = res; # 保存响应的 HTML 内容
const dom = new JSDOM(html); # 将 HTML 内容解析为 DOM 对象
const title = dom.window.document.title; # 获取页面标题
const markdown = this.transform(dom); # 转换 DOM 为 Markdown 格式
state.title = title; # 保存标题到状态
# 移除第一行,因为它将作为标题
const lines = markdown.split('\n'); # 按行分割 Markdown 内容
lines.shift(); # 移除第一行
# 移除后续的空行
let i = 0;
while (i < lines.length - 1) {
if (lines[i].trim() === '' && lines[i + 1].trim() === '') {
lines.splice(i, 1); # 删除连续的空行
} else {
i++; # 移动到下一行
}
}
let page = lines.join('\n'); # 将处理后的行重新组合为字符串
# 第一个换行符不是错误
let text_result = `\nURL: ${url}\n${page}`; # 创建结果字符串,包含 URL 和页面内容
state.lines = text_result.split('\n'); # 将结果按行分割
# 所有行只来自一个来源
state.line_source = {}; # 初始化行来源对象
state.line_source[`0-${state.lines.length - 1}`] = {
type: 'webpage', # 设置行来源类型
title: title, # 保存页面标题
url: url, # 保存页面 URL
lines: state.lines, # 保存行内容
};
let message = this.formatPage(state); # 格式化页面状态为消息
const returnContentType = 'browser_result'; # 定义返回内容类型
return {
contentType: returnContentType, # 返回内容类型
roleMetadata: returnContentType, # 返回角色元数据
message, # 返回格式化消息
metadataList: state.links_meta, # 返回链接元数据
};
} catch (err) {
# 捕获解析错误,抛出新的错误信息
throw new Error(`parse error: ${err}`);
}
})
# 捕获请求错误并进行处理
.catch(err => {
logger.error(err.message); # 记录错误信息
if (err.code === 'ECONNABORTED') {
# 如果是超时错误,抛出超时信息
throw new Error(`Timeout while loading page w/ URL: ${url}`);
}
# 否则抛出加载失败信息
throw new Error(`Failed to load page w/ URL: ${url}`);
});
},
},
};
# 构造函数初始化状态
constructor() {
this.state = {
pageStack: [], # 页面栈初始化为空
quotes: {}, # 初始化引用对象为空
quoteCounter: 7, # 初始化引用计数器
};
# 移除 turndown 服务中的 script 和 style 标签
this.turndownService.remove('script');
this.turndownService.remove('style');
# 为 turndown 添加规则
// 为 'reference' 类型的链接添加解析规则
this.turndownService.addRule('reference', {
// 过滤函数,判断节点是否为符合条件的链接
filter: function (node, options: any): boolean {
return (
// 只有在使用内联样式时才返回 true
options.linkStyle === 'inlined' &&
// 节点必须是 'A' 标签
node.nodeName === 'A' &&
// 'href' 属性必须存在
node.getAttribute('href') !== undefined
);
},
// 替换函数,用于生成特定格式的链接
replacement: (content, node, options): string => {
// 获取当前页面状态的最新记录
let state = this.state.pageStack[this.state.pageStack.length - 1];
// 如果内容为空或节点没有 'getAttribute' 方法,则返回空字符串
if (!content || !('getAttribute' in node)) return '';
let href = undefined;
try {
// 确保节点具有 'getAttribute' 方法
if ('getAttribute' in node) {
// 从 'href' 属性中提取主机名
const hostname = new URL(node.getAttribute('href')!).hostname;
// 如果主机名与当前状态的主机名相同或不存在,则不附加主机名
if (hostname === state.hostname || !hostname) {
href = '';
} else {
// 否则,附加主机名
href = '†' + hostname;
}
}
} catch (e) {
// 捕获异常以避免显示错误的链接
href = '';
}
// 如果 href 仍然未定义,则返回空字符串
if (href === undefined) return '';
// 获取链接的完整 URL
const url = node.getAttribute('href')!;
// 查找当前链接在状态中的索引
let linkId = state.links.findIndex(link => link === url);
// 如果链接不存在,则为其分配新的 ID
if (linkId === -1) {
linkId = state.aCounter++;
// logger.debug(`New link[${linkId}]: ${url}`);
// 将新链接的元数据推入状态中
state.links_meta.push({
type: 'webpage',
title: node.textContent!,
url: href,
text: node.textContent!,
});
// 将新链接添加到状态链接数组中
state.links.push(url);
}
// 返回格式化的链接字符串
return `【${linkId}†${node.textContent}${href}】`;
},
});
// 为 'img' 标签添加解析规则
this.turndownService.addRule('img', {
// 过滤条件,指定过滤 'img' 标签
filter: 'img',
// 替换函数,用于生成特定格式的图像标记
replacement: (content, node, options): string => {
// 获取当前页面状态的最新记录
let state = this.state.pageStack[this.state.pageStack.length - 1];
// 返回格式化的图像标记字符串
return `[Image ${state.imgCounter++}]`;
},
});
// 为 'li' 标签添加解析规则,并调整缩进
this.turndownService.addRule('list', {
// 过滤条件,指定过滤 'li' 标签
filter: 'li',
// 替换函数,用于生成特定格式的列表项
replacement: function (content, node, options) {
// 清理内容的换行符
content = content
.replace(/^\n+/, '') // 移除开头的多余换行符
.replace(/\n+$/, '\n') // 将结尾的多余换行符替换为一个换行符
.replace(/\n/gm, '\n '); // 在每行前添加缩进
// 确定列表前缀符号
let prefix = options.bulletListMarker + ' ';
// 获取父节点,确保是列表
const parent = node.parentNode! as Element;
// 如果父节点是有序列表,计算索引
if (parent.nodeName === 'OL') {
const start = parent.getAttribute('start');
const index = Array.prototype.indexOf.call(parent.children, node);
// 根据列表的起始值调整前缀
prefix = (start ? Number(start) + index : index + 1) + '. ';
}
// 返回格式化的列表项字符串,处理换行
return ' ' + prefix + content + (node.nextSibling && !/\n$/.test(content) ? '\n' : '');
},
});
// 为 'strong' 和 'b' 标签添加解析规则,移除加粗效果
this.turndownService.addRule('emph', {
// 过滤条件,指定过滤 'strong' 和 'b' 标签
filter: ['strong', 'b'],
// 替换函数,返回原始内容
replacement: function (content, node, options) {
// 如果内容为空,则返回空字符串
if (!content.trim()) return '';
// 返回原始内容
return content;
},
});
}
// 定义抽象方法,用于处理每一行内容并返回一个或多个 ActionResult
abstract actionLine(content: string): Promise<ActionResult | ActionResult[]>;
// 异步方法,处理传入的内容并返回 ToolObservation 数组
async action(content: string): Promise<ToolObservation[]> {
// 将内容按行分割成数组
const lines = content.split('\n');
// 初始化结果数组,用于存储 ActionResult
let results: ActionResult[] = [];
// 遍历每一行
for (const line of lines) {
// 记录当前处理的行信息
logger.info(`Action line: ${line}`)
try {
// 调用 actionLine 方法处理当前行,并等待结果
const lineActionResult = await this.actionLine(line);
// 记录当前行的处理结果
logger.debug(`Action line result: ${JSON.stringify(lineActionResult, null, 2)}`);
// 检查结果是否为数组
if (Array.isArray(lineActionResult)) {
// 将数组结果合并到 results 中
results = results.concat(lineActionResult);
} else {
// 将单个结果添加到 results 中
results.push(lineActionResult);
}
} catch (err) {
// 定义错误内容类型
const returnContentType = 'system_error';
// 将错误信息封装到结果中
results.push({
contentType: returnContentType,
roleMetadata: returnContentType,
message: `Error when executing command ${line}\n${err}`,
metadata: {
failedCommand: line,
},
});
}
}
// 初始化观察结果数组
const observations: ToolObservation[] = [];
// 遍历每个 ActionResult 以生成 ToolObservation
for (const result of results) {
// 构建观察对象
const observation: ToolObservation = {
contentType: result.contentType,
result: result.message,
roleMetadata: result.roleMetadata,
metadata: result.metadata ?? {},
};
// 如果结果中有 metadataList,将其添加到观察对象的 metadata 中
if (result.metadataList) {
observation.metadata.metadata_list = result.metadataList;
}
// 将观察对象添加到观察结果数组中
observations.push(observation);
}
// 返回所有观察结果
return observations;
}
// 后处理方法,用于处理消息和元数据
postProcess(message: Message, metadata: any) {
// 正则模式,用于匹配引用内容
const quotePattern = /【(.+?)†(.*?)】/g;
// 获取消息内容
const content = message.content;
// 初始化匹配变量
let match;
// 初始化引用数组
let citations: Citation[] = [];
// 定义引用格式类型
const citation_format_type = 'tether_og';
// 当匹配到引文模式时循环处理
while ((match = quotePattern.exec(content))) {
// 记录当前匹配的引文
logger.debug(`Citation match: ${match[0]}`);
// 获取匹配的起始索引
const start_ix = match.index;
// 获取匹配的结束索引
const end_ix = match.index + match[0].length;
// 初始化无效原因为 undefined
let invalid_reason = undefined;
// 声明元数据变量,类型为 TetherQuoteMetadata
let metadata: TetherQuoteMetadata;
// 尝试块,处理引文解析
try {
// 解析被引用消息的索引
let cited_message_idx = parseInt(match[1]);
// 获取证据文本
let evidence_text = match[2];
// 从状态中获取引用内容
let quote = this.state.quotes[cited_message_idx.toString()];
// 如果引用未定义,记录无效原因
if (quote === undefined) {
invalid_reason = `'Referenced message ${cited_message_idx} in citation 【${cited_message_idx}†${evidence_text}】 is not a quote or tether browsing display.'`;
// 记录错误信息
logger.error(`Triggered citation error with quote undefined: ${invalid_reason}`);
// 将无效引文信息推入 citations 数组
citations.push({
citation_format_type,
start_ix,
end_ix,
invalid_reason,
});
} else {
// 定义额外信息
let extra: TetherQuoteExtra = {
cited_message_idx,
evidence_text,
};
// 获取引用的元数据
const quote_metadata = quote.metadata[0];
// 构造引文元数据对象
metadata = {
type: 'webpage',
title: quote_metadata.title,
url: quote_metadata.url,
text: quote_metadata.lines.join('\n'),
extra,
};
// 将有效引文信息推入 citations 数组
citations.push({
citation_format_type,
start_ix,
end_ix,
metadata,
});
}
} catch (err) {
// 记录异常信息
logger.error(`Triggered citation error: ${err}`);
// 记录无效原因为捕获的异常
invalid_reason = `Citation Error: ${err}`;
// 将无效引文信息推入 citations 数组
citations.push({
start_ix,
end_ix,
citation_format_type,
invalid_reason,
});
}
}
// 将引文数组添加到元数据中
metadata.citations = citations;
}
// 获取当前状态
getState() {
// 返回状态对象
return this.state;
}
} // 结束类或块的作用域
export class SimpleBrowser extends BaseBrowser { // 定义一个名为 SimpleBrowser 的类,继承自 BaseBrowser
public description = 'SimpleBrowser'; // 声明一个公开属性 description,值为 'SimpleBrowser'
constructor() { // 构造函数
super(); // 调用父类的构造函数
}
async actionLine(content: string): Promise<ActionResult | ActionResult[]> { // 异步方法 actionLine,接受一个字符串参数 content,返回 ActionResult 或 ActionResult 数组
const regex = /(\w+)\(([^)]*)\)/; // 正则表达式,用于匹配函数名和参数
const matches = content.match(regex); // 在 content 中查找匹配项
if (matches) { // 如果找到匹配项
const functionName = matches[1]; // 提取函数名
let args_string = matches[2]; // 提取参数字符串
if (functionName === 'mclick') { // 如果函数名为 'mclick'
args_string = args_string.trim().slice(1, -1); // 去除参数字符串的 '[' 和 ']'
}
const args = args_string.split(',').map(arg => arg.trim()); // 将参数字符串按逗号分割,并去除空格
let result; // 声明结果变量
switch (functionName) { // 根据函数名执行不同的逻辑
case 'search': // 如果函数名为 'search'
logger.debug(`SimpleBrowser action search ${args[0].slice(1, -1)}`); // 记录调试信息
const recency_days = /(^|\D)(\d+)($|\D)/.exec(args[1])?.[2] as undefined | `${number}`; // 提取 recency_days 参数
result = await this.functions.search( // 调用 functions 对象的 search 方法
args[0].slice(1, -1), // 去除查询字符串的引号
recency_days && Number(recency_days), // 如果 recency_days 存在,则转换为数字
);
break; // 结束 switch 语句
case 'open_url': // 如果函数名为 'open_url'
logger.debug(`SimpleBrowser action open_url ${args[0].slice(1, -1)}`); // 记录调试信息
result = await this.functions.open_url(args[0].slice(1, -1)); // 调用 functions 对象的 open_url 方法
break; // 结束 switch 语句
case 'mclick': // 如果函数名为 'mclick'
logger.debug(`SimpleBrowser action mclick ${args}`); // 记录调试信息
result = await this.functions.mclick(args.map(x => parseInt(x))); // 调用 functions 对象的 mclick 方法,传入解析后的参数
break; // 结束 switch 语句
default: // 如果没有匹配的函数名
throw new Error(`Parse Error: ${content}`); // 抛出解析错误
}
return result; // 返回结果
} else { // 如果没有找到匹配项
throw new Error('Parse Error'); // 抛出解析错误
}
}
}
if (require.main === module) { // 如果当前模块是主模块
(async () => { // 定义并立即执行一个异步函数
let browser = new SimpleBrowser(); // 实例化 SimpleBrowser 对象
let demo = async (action: string) => { // 定义一个异步函数 demo,接受一个字符串参数 action
logger.info(` ------ Begin of Action: ${action} ------`); // 记录操作开始信息
let results = await browser.action(action); // 调用 browser 对象的 action 方法,获取结果
for (const [idx, result] of results.entries()) { // 遍历结果数组
logger.info(`[Result ${idx}] contentType: ${result.contentType}`); // 记录结果的 contentType
logger.info(`[Result ${idx}] roleMetadata: ${result.roleMetadata}`); // 记录结果的 roleMetadata
logger.info(`[Result ${idx}] result: ${result.result}`); // 记录结果
logger.info(`[Result ${idx}] metadata: ${JSON.stringify(result.metadata, null, 2)}`); // 记录结果的 metadata
}
logger.info(` ------ End of Action: ${action} ------\n\n`); // 记录操作结束信息
};
await demo("search('Apple Latest News')"); // 执行搜索操作
await demo('mclick([0, 1, 5, 6])'); // 执行 mclick 操作
await demo('mclick([1, 999999])'); // 执行 mclick 操作,包含超出范围的索引
await demo("open_url('https://chatglm.cn')"); // 执行打开 URL 操作
await demo("search('zhipu latest News')"); // 执行搜索操作
await demo('mclick([0, 1, 5, 6])'); // 再次执行 mclick 操作
})(); // 结束立即执行的异步函数
}
.\chatglm4-finetune\composite_demo\browser\src\config.ts
# 导出一个默认的配置对象
export default {
# 设置日志级别为 'debug'
LOG_LEVEL: 'debug',
# 设置浏览器超时时间为 10000 毫秒
BROWSER_TIMEOUT: 10000,
# 设置 Bing 搜索 API 的 URL
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/',
# 设置 Bing 搜索 API 的密钥
BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY',
# 自定义配置 ID 的占位符,用户应在此处填入实际值
CUSTOM_CONFIG_ID : 'YOUR_CUSTOM_CONFIG_ID', //将您的Custom Configuration ID放在此处
# 设置主机地址为 'localhost'
HOST: 'localhost',
# 设置端口为 3000
PORT: 3000,
};
.\chatglm4-finetune\composite_demo\browser\src\server.ts
# 导入 express 和相关类型
import express, { Express, Request, Response } from 'express';
# 导入自定义浏览器类
import { SimpleBrowser } from './browser';
# 导入配置文件
import config from './config';
# 导入日志工具
import { logger } from './utils';
# 初始化一个记录会话历史的对象
const session_history: Record<string, SimpleBrowser> = {};
# 创建一个 Express 应用实例
const app: Express = express();
# 中间件,解析 JSON 格式的请求体
app.use(express.json());
# 定义 POST 请求的根路由
app.post('/', async (req: Request, res: Response) => {
# 从请求体中解构出 session_id 和 action
const {
session_id,
action,
}: {
session_id: string;
action: string;
} = req.body;
# 记录会话 ID 到日志
logger.info(`session_id: ${session_id}`);
# 记录动作到日志
logger.info(`action: ${action}`);
# 如果 session_history 中没有该 session_id,创建新的 SimpleBrowser 实例
if (!session_history[session_id]) {
session_history[session_id] = new SimpleBrowser();
}
# 获取对应 session_id 的浏览器实例
const browser = session_history[session_id];
try {
# 执行浏览器动作并返回 JSON 响应
res.json(await browser.action(action));
} catch (err) {
# 记录错误到日志
logger.error(err);
# 返回 400 状态码和错误信息
res.status(400).json(err);
}
})
# 处理 SIGINT 信号以优雅退出进程
process.on('SIGINT', () => {
process.exit(0);
});
# 处理未捕获的异常并记录到日志
process.on('uncaughtException', e => {
logger.error(e);
});
# 从配置中解构出主机和端口
const { HOST, PORT } = config;
# 创建一个自执行的异步函数以启动服务器
(async () => {
# 监听指定的端口和主机
app.listen(PORT, HOST, () => {
# 记录服务器启动信息
logger.info(`⚡️[server]: Server is running at http://${HOST}:${PORT}`);
try {
# 发送 "ready" 信号给进程
(<any>process).send('ready');
} catch (err) {}
});
})();
.\chatglm4-finetune\composite_demo\browser\src\types.ts
# 定义文件接口,包含文件的 ID、名称和大小
export interface File {
# 文件的唯一标识符
id: string;
# 文件的名称
name: string;
# 文件的大小,以字节为单位
size: number;
}
# 定义元数据接口,包含文件列表和引用字符串(可选)
export interface Metadata {
# 文件列表,类型为 File 数组(可选)
files?: File[];
# 引用字符串(可选)
reference?: string;
}
# 定义消息接口,表示用户、助手、系统或观察者的消息
export interface Message {
# 消息角色,限定为特定的字符串类型
role: 'user' | 'assistant' | 'system' | 'observation';
# 消息的元数据,类型为字符串
metadata: string;
# 消息的内容,类型为字符串
content: string;
# 请求元数据(可选)
request_metadata?: Metadata;
}
# 定义工具观察接口,描述工具执行结果
export interface ToolObservation {
# 内容类型,表示结果的 MIME 类型
contentType: string;
# 工具的执行结果
result: string;
# 可能的文本内容(可选)
text?: string;
# 观察者角色的元数据(可选)
roleMetadata?: string; // metadata for <|observation|>${metadata}
# 响应的元数据,类型为任意
metadata: any; // metadata for response
}
.\chatglm4-finetune\composite_demo\browser\src\utils.ts
# 导入 winston 日志库
import winston from 'winston';
# 导入配置文件
import config from './config';
# 定义 TimeoutError 类,继承自 Error
export class TimeoutError extends Error {}
# 获取日志级别配置
const logLevel = config.LOG_LEVEL;
# 创建一个 logger 实例,用于记录日志
export const logger = winston.createLogger({
# 设置日志级别
level: logLevel,
# 定义日志格式,包括颜色化和自定义输出
format: winston.format.combine(
winston.format.colorize(),
winston.format.printf(info => {
# 格式化日志信息输出,显示级别和消息
return `${info.level}: ${info.message}`;
}),
),
# 定义日志传输方式,这里使用控制台输出
transports: [new winston.transports.Console()],
});
# 在控制台输出当前日志级别
console.log('LOG_LEVEL', logLevel);
# 定义一个将高分辨率时间转换为毫秒的函数
export const parseHrtimeToMillisecond = (hrtime: [number, number]): number => {
# 将高分辨率时间转换为毫秒
return (hrtime[0] + hrtime[1] / 1e9) * 1000;
};
# 定义一个封装 Promise 的函数,用于返回其值和执行时间
export const promiseWithTime = <T>(
promise: Promise<T>
): Promise<{
value: T;
time: number;
}> => {
# 返回一个新的 Promise
return new Promise((resolve, reject) => {
# 记录开始时间
const startTime = process.hrtime();
# 处理传入的 Promise
promise
.then(value => {
# 成功时解析,返回值和执行时间
resolve({
value: value,
time: parseHrtimeToMillisecond(process.hrtime(startTime))
});
})
.catch(err => reject(err)); # 捕获错误并拒绝
});
};
# 定义一个带超时功能的 Promise 函数
export const withTimeout = <T>(
millis: number,
promise: Promise<T>
): Promise<{
value: T;
time: number;
}> => {
# 创建一个超时的 Promise
const timeout = new Promise<{ value: T; time: number }>((_, reject) =>
# 指定时间后拒绝 Promise,抛出 TimeoutError
setTimeout(() => reject(new TimeoutError()), millis)
);
# 竞争两个 Promise,哪个先完成就返回哪个
return Promise.race([promiseWithTime(promise), timeout]);
};
GLM-4-9B Web Demo
Read this in English
安装
我们建议通过 Conda 进行环境管理。
执行以下命令新建一个 conda 环境并安装所需依赖:
conda create -n glm-4-demo python=3.12
conda activate glm-4-demo
pip install -r requirements.txt
请注意,本项目需要 Python 3.10 或更高版本。
此外,使用 Code Interpreter 还需要安装 Jupyter 内核:
ipython kernel install --name glm-4-demo --user
您可以修改 ~/.local/share/jupyter/kernels/glm-4-demo/kernel.json
来改变 Jupyter 内核的配置,包括内核的启动参数等。例如,若您希望在使用 All Tools 的 Python 代码执行能力时使用 Matplotlib 画图,可以在 argv
数组中添加 "--matplotlib=inline"
。
若要使用浏览器和搜索功能,还需要启动浏览器后端。首先,根据 Node.js
官网的指示安装 Node.js,然后安装包管理器 PNPM 之后安装浏览器服务的依赖:
cd browser
npm install -g pnpm
pnpm install
运行
-
修改
browser/src/config.ts
中的BING_SEARCH_API_KEY
配置浏览器服务需要使用的 Bing 搜索 API Key:export default { BROWSER_TIMEOUT: 10000, BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0', BING_SEARCH_API_KEY: '<PUT_YOUR_BING_SEARCH_KEY_HERE>', HOST: 'localhost', PORT: 3000, };
如果您注册的是Bing Customer Search的API,您可以修改您的配置文件为如下,并且填写您的Custom Configuration ID:
export default { LOG_LEVEL: 'debug', BROWSER_TIMEOUT: 10000, BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/', BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY', CUSTOM_CONFIG_ID : 'YOUR_CUSTOM_CONFIG_ID', //将您的Custom Configuration ID放在此处 HOST: 'localhost', PORT: 3000, };
-
文生图功能需要调用 CogView API。修改
src/tools/config.py
,提供文生图功能需要使用的 智谱 AI 开放平台 API Key:BROWSER_SERVER_URL = 'http://localhost:3000' IPYKERNEL = 'glm-4-demo' ZHIPU_AI_KEY = '<PUT_YOUR_ZHIPU_AI_KEY_HERE>' COGVIEW_MODEL = 'cogview-3'
-
启动浏览器后端,在单独的 shell 中:
cd browser pnpm start
-
运行以下命令在本地加载模型并启动 demo:
streamlit run src/main.py
之后即可从命令行中看到 demo 的地址,点击即可访问。初次访问需要下载并加载模型,可能需要花费一定时间。
如果已经在本地下载了模型,可以通过 export *_MODEL_PATH=/path/to/model
来指定从本地加载模型。可以指定的模型包括:
CHAT_MODEL_PATH
: 用于 All Tools 模式与文档解读模式,默认为THUDM/glm-4-9b-chat
。VLM_MODEL_PATH
: 用于 VLM 模式,默认为THUDM/glm-4v-9b
。
Chat 模型支持使用 vLLM 推理。若要使用,请安装 vLLM 并设置环境变量 USE_VLLM=1
。
Chat 模型支持使用 OpenAI API 推理。若要使用,请启动basic_demo目录下的openai_api_server并设置环境变量 USE_API=1
。该功能可以解耦推理服务器和demo服务器。
如果需要自定义 Jupyter 内核,可以通过 export IPYKERNEL=<kernel_name>
来指定。
使用
GLM-4 Demo 拥有三种模式:
- All Tools: 具有完整工具调用能力的对话模式,原生支持网页浏览、代码执行、图片生成,并支持自定义工具。
- 文档解读: 支持上传文档进行文档解读与对话。
- 多模态: 支持上传图像进行图像理解与对话。
All Tools
本模式兼容 ChatGLM3-6B 的工具注册流程。
- 代码能力,绘图能力,联网能力已经自动集成,用户只需按照要求配置对应的Key。
- 本模式下不支持系统提示词,模型会自动构建提示词。
对话模式下,用户可以直接在侧边栏修改 top_p, temperature 等参数来调整模型的行为。
与模型对话时,模型将会自主决定进行工具调用。
由于原始结果可能较长,默认情况下工具调用结果被隐藏,可以通过展开折叠框查看原始的工具调用结果。
模型拥有进行网页搜索和 Python 代码执行的能力。同时,模型也可以连续调用多个工具。例如:
此时模型通过调用浏览器工具进行搜索获取到了需要的数据,之后将会调用 Python 工具执行代码,利用 Matplotlib 绘图:
如果提供了智谱开放平台 API Key,模型也可以调用 CogView 进行图像生成:
自定义工具
可以通过在 tool_registry.py
中注册新的工具来增强模型的能力。只需要使用 @register_tool
装饰函数即可完成注册。对于工具声明,函数名称即为工具的名称,函数 docstring
即为工具的说明;对于工具的参数,使用 Annotated[typ: type, description: str, required: bool]
标注参数的类型、描述和是否必须。
例如,get_weather
工具的注册如下:
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
文档解读
用户可以上传文档,使用 GLM-4-9B的长文本能力,对文本进行理解。可以解析 pptx,docx,pdf等文件。
- 本模式下不支持工具调用和系统提示词。
- 如果文本很长,可能导致模型需要的显存较高,请确认你的硬件配置。
多模态
多模态模式下,用户可以利用 GLM-4V 的多模态理解能力,上传图像并与 GLM-4V 进行多轮对话:
用户可以上传图片,使用 GLM-4-9B的图像理解能力,对图片进行理解。
- 本模式必须使用 glm-4v-9b 模型。
- 本模式下不支持工具调用和系统提示词。
- 模型仅能对一张图片进行理解和联系对话,如需更换图片,需要开启一个新的对话。
- 图像支持的分辨率为 1120 x 1120
GLM-4-9B Web Demo
Installation
We recommend using Conda for environment management.
Execute the following commands to create a conda environment and install the required dependencies:
conda create -n glm-4-demo python=3.12
conda activate glm-4-demo
pip install -r requirements.txt
Please note that this project requires Python 3.10 or higher.
In addition, you need to install the Jupyter kernel to use the Code Interpreter:
ipython kernel install --name glm-4-demo --user
You can modify ~/.local/share/jupyter/kernels/glm-4-demo/kernel.json
to change the configuration of the Jupyter
kernel, including the kernel startup parameters. For example, if you want to use Matplotlib to draw when using the
Python code execution capability of All Tools, you can add "--matplotlib=inline"
to the argv
array.
To use the browser and search functions, you also need to start the browser backend. First, install Node.js according to
the instructions on the Node.js
official website, then install the package manager PNPM and then install the browser service
dependencies:
cd browser
npm install -g pnpm
pnpm install
Run
- Modify
BING_SEARCH_API_KEY
inbrowser/src/config.ts
to configure the Bing Search API Key that the browser service
needs to use:
export default {
BROWSER_TIMEOUT: 10000,
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0',
BING_SEARCH_API_KEY: '<PUT_YOUR_BING_SEARCH_KEY_HERE>',
HOST: 'localhost',
PORT: 3000,
};
- The Wenshengtu function needs to call the CogView API. Modify
src/tools/config.py
, provide the Zhipu AI Open Platform API Key required for the Wenshengtu function:
BROWSER_SERVER_URL = 'http://localhost:3000'
IPYKERNEL = 'glm4-demo'
ZHIPU_AI_KEY = '<PUT_YOUR_ZHIPU_AI_KEY_HERE>'
COGVIEW_MODEL = 'cogview-3'
- Start the browser backend in a separate shell:
cd browser
pnpm start
- Run the following commands to load the model locally and start the demo:
streamlit run src/main.py
Then you can see the demo address from the command line and click it to access it. The first access requires downloading
and loading the model, which may take some time.
If you have downloaded the model locally, you can specify to load the model from the local
by export *_MODEL_PATH=/path/to/model
. The models that can be specified include:
-
CHAT_MODEL_PATH
: used for All Tools mode and document interpretation mode, the default isTHUDM/glm-4-9b-chat
. -
VLM_MODEL_PATH
: used for VLM mode, the default isTHUDM/glm-4v-9b
.
The Chat model supports reasoning using vLLM. To use it, please install vLLM and
set the environment variable USE_VLLM=1
.
The Chat model also supports reasoning using OpenAI API. To use it, please run openai_api_server.py
in basic_demo
and set the environment variable USE_API=1
. This function is used to deploy inference server and demo server in different machine.
If you need to customize the Jupyter kernel, you can specify it by export IPYKERNEL=<kernel_name>
.
Usage
GLM4 Demo has three modes:
- All Tools mode
- VLM mode
- Text interpretation mode
All Tools mode
You can enhance the model's capabilities by registering new tools in tool_registry.py
. Just use @register_tool
decorated function to complete the registration. For tool declarations, the function name is the name of the tool, and
the function docstring
is the description of the tool; for tool parameters, use Annotated[typ: type, description: str, required: bool]
to
annotate the parameter type, description, and whether it is required.
For example, the registration of the get_weather
tool is as follows:
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
This mode is compatible with the tool registration process of ChatGLM3-6B.
- Code capability, drawing capability, and networking capability have been automatically integrated. Users only need to
configure the corresponding Key as required. - System prompt words are not supported in this mode. The model will automatically build prompt words.
Text interpretation mode
Users can upload documents and use the long text capability of GLM-4-9B to understand the text. It can parse pptx, docx,
pdf and other files.
- Tool calls and system prompt words are not supported in this mode.
- If the text is very long, the model may require a high amount of GPU memory. Please confirm your hardware
configuration.
Image Understanding Mode
Users can upload images and use the image understanding capabilities of GLM-4-9B to understand the images.
- This mode must use the glm-4v-9b model.
- Tool calls and system prompts are not supported in this mode.
- The model can only understand and communicate with one image. If you need to change the image, you need to open a new
conversation. - The supported image resolution is 1120 x 1120
.\chatglm4-finetune\composite_demo\src\client.py
# 这是 composite_demo 的客户端部分
"""
# 提供两个客户端,HFClient 和 VLLMClient,用于与模型进行交互
We provide two clients, HFClient and VLLMClient, which are used to interact with the model.
# HFClient 用于与 transformers 后端交互,VLLMClient 用于与 VLLM 模型交互
The HFClient is used to interact with the transformers backend, and the VLLMClient is used to interact with the VLLM model.
"""
# 导入 JSON 模块用于处理 JSON 数据
import json
# 导入 Generator 类型用于类型注解
from collections.abc import Generator
# 导入 deepcopy 函数用于深拷贝对象
from copy import deepcopy
# 导入 Enum 和 auto 用于定义枚举类型
from enum import Enum, auto
# 导入 Protocol 类型用于定义协议
from typing import Protocol
# 导入 Streamlit 库以构建用户界面
import streamlit as st
# 从 conversation 模块导入 Conversation 类和 build_system_prompt 函数
from conversation import Conversation, build_system_prompt
# 从 tools.tool_registry 导入所有工具的注册列表
from tools.tool_registry import ALL_TOOLS
# 定义客户端类型的枚举
class ClientType(Enum):
# 定义 HF 类型
HF = auto()
# 定义 VLLM 类型
VLLM = auto()
# 定义 API 类型
API = auto()
# 定义客户端协议,包含初始化和生成流的方法
class Client(Protocol):
# 定义初始化方法,接受模型路径
def __init__(self, model_path: str): ...
# 定义生成流的方法,接受工具和历史记录
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]: ...
# 处理输入数据的函数
def process_input(history: list[dict], tools: list[dict], role_name_replace:dict=None) -> list[dict]:
# 初始化聊天历史列表
chat_history = []
# 如果有工具,构建系统提示并添加到聊天历史
#if len(tools) > 0:
chat_history.append(
{"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
)
# 遍历历史对话
for conversation in history:
# 清理角色名称
role = str(conversation.role).removeprefix("<|").removesuffix("|>")
# 如果提供了角色替换字典,更新角色名称
if role_name_replace:
role = role_name_replace.get(role, role)
# 构建对话项
item = {
"role": role,
"content": conversation.content,
}
# 如果有元数据,添加到对话项
if conversation.metadata:
item["metadata"] = conversation.metadata
# 仅对用户角色添加图像
if role == "user" and conversation.image:
item["image"] = conversation.image
# 将对话项添加到聊天历史
chat_history.append(item)
# 返回聊天历史
return chat_history
# 处理响应数据的函数
def process_response(output, history):
# 初始化内容字符串
content = ""
# 深拷贝历史记录以避免修改原始数据
history = deepcopy(history)
# 分割输出,处理每个助手响应
for response in output.split("<|assistant|>"):
# 如果响应中有换行符
if "\n" in response:
# 分割元数据和内容
metadata, content = response.split("\n", maxsplit=1)
else:
# 如果没有换行,元数据为空,内容为响应
metadata, content = "", response
# 如果元数据为空,则处理内容
if not metadata.strip():
content = content.strip()
# 将助手的响应添加到历史记录
history.append({"role": "assistant", "metadata": metadata, "content": content})
# 替换特定文本
content = content.replace("[[训练时间]]", "2023年")
else:
# 否则,添加元数据和内容到历史记录
history.append({"role": "assistant", "metadata": metadata, "content": content})
# 如果历史记录的第一项是系统角色,并且包含工具
if history[0]["role"] == "system" and "tools" in history[0]:
# 解析内容为参数
parameters = json.loads(content)
content = {"name": metadata.strip(), "parameters": parameters}
else:
# 否则,将内容结构化
content = {"name": metadata.strip(), "content": content}
# 返回处理后的内容和历史记录
return content, history
# 缓存资源以提高性能,限制缓存条目数
@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def get_client(model_path, typ: ClientType) -> Client:
# 根据传入的客户端类型决定使用哪个客户端
match typ:
# 匹配到 HF 类型时,导入 HFClient
case ClientType.HF:
from clients.hf import HFClient
# 返回 HFClient 实例,传入模型路径
return HFClient(model_path)
# 匹配到 VLLM 类型时,尝试导入 VLLMClient
case ClientType.VLLM:
try:
from clients.vllm import VLLMClient
# 捕获导入错误,并添加提示信息
except ImportError as e:
e.msg += "; did you forget to install vLLM?"
raise
# 返回 VLLMClient 实例,传入模型路径
return VLLMClient(model_path)
# 匹配到 API 类型时,导入 APIClient
case ClientType.API:
from clients.openai import APIClient
# 返回 APIClient 实例,传入模型路径
return APIClient(model_path)
# 如果没有匹配到支持的客户端类型,抛出未实现错误
raise NotImplementedError(f"Client type {typ} is not supported.")
.\chatglm4-finetune\composite_demo\src\clients\hf.py
"""
HuggingFace client. # HuggingFace 客户端的文档说明
"""
import threading # 导入 threading 模块以支持多线程
from collections.abc import Generator # 从 abc 模块导入 Generator 类型以定义生成器
from threading import Thread # 从 threading 模块导入 Thread 类以便于创建线程
import torch # 导入 PyTorch 库以进行张量操作
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer # 从 transformers 导入必要的类
from client import Client, process_input, process_response # 从 client 模块导入 Client 类和处理函数
from conversation import Conversation # 从 conversation 模块导入 Conversation 类
class HFClient(Client): # 定义 HFClient 类,继承自 Client 类
def __init__(self, model_path: str): # 构造函数,接收模型路径作为参数
# 使用预训练的模型路径初始化分词器,信任远程代码
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True,
)
# 使用预训练的模型路径初始化因果语言模型,信任远程代码,设置数据类型和设备映射
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16, # 使用 bfloat16 数据类型
device_map="cuda", # 将模型加载到 GPU
).eval() # 将模型设置为评估模式
def generate_stream( # 定义生成流的方法,接收工具、历史记录和可变参数
self,
tools: list[dict], # 工具列表,每个工具为字典
history: list[Conversation], # 对话历史记录列表
**parameters, # 其他参数
) -> Generator[tuple[str | dict, list[dict]]]: # 返回生成器,输出为字符串或字典的元组和字典列表
# 处理输入的对话历史和工具
chat_history = process_input(history, tools)
# 使用分词器将对话历史转化为模型输入格式
model_inputs = self.tokenizer.apply_chat_template(
chat_history,
add_generation_prompt=True, # 添加生成提示
tokenize=True, # 对输入进行分词
return_tensors="pt", # 返回 PyTorch 张量
return_dict=True, # 返回字典格式
).to(self.model.device) # 将模型输入移到模型的设备上
# 初始化文本迭代流处理器
streamer = TextIteratorStreamer(
tokenizer=self.tokenizer, # 使用分词器
timeout=5, # 设置超时时间为5秒
skip_prompt=True, # 跳过提示
)
# 准备生成参数,包括模型输入和其他参数
generate_kwargs = {
**model_inputs, # 解包模型输入
"streamer": streamer, # 添加流处理器
"eos_token_id": [151329, 151336, 151338], # 设置结束标记 ID
"do_sample": True, # 启用采样
}
generate_kwargs.update(parameters) # 更新生成参数,包含额外的可变参数
# 创建线程以生成文本
t = Thread(target=self.model.generate, kwargs=generate_kwargs) # 将生成方法作为线程目标
t.start() # 启动线程
total_text = "" # 初始化总文本字符串
for token_text in streamer: # 遍历生成的每个令牌文本
total_text += token_text # 将令牌文本追加到总文本中
# 生成并返回处理后的响应
yield process_response(total_text, chat_history)
.\chatglm4-finetune\composite_demo\src\clients\openai.py
"""
OpenAI API client. # 定义一个文档字符串,说明该模块是 OpenAI API 客户端
"""
from openai import OpenAI # 从 openai 模块导入 OpenAI 类
from collections.abc import Generator # 从 collections.abc 导入 Generator 类型
from client import Client, process_input, process_response # 从 client 模块导入 Client 类及处理输入和输出的函数
from conversation import Conversation # 从 conversation 模块导入 Conversation 类
def format_openai_tool(origin_tools): # 定义一个函数,将原始工具格式化为 OpenAI 工具
openai_tools = [] # 初始化一个空列表,用于存储格式化后的 OpenAI 工具
for tool in origin_tools: # 遍历每个原始工具
openai_param={} # 初始化一个空字典,用于存储工具参数
for param in tool['params']: # 遍历工具的参数
openai_param[param['name']] = {} # 将每个参数名称添加到字典中,值为空字典
openai_tool = { # 创建一个字典,表示格式化后的 OpenAI 工具
"type": "function", # 设置工具类型为函数
"function": { # 定义函数相关的信息
"name": tool['name'], # 设置函数名称
"description": tool['description'], # 设置函数描述
"parameters": { # 定义函数参数
"type": "object", # 设置参数类型为对象
"properties": { # 定义参数的属性
param['name']: {'type': param['type'], 'description': param['description']} for param in tool['params'] # 将每个参数名称及其类型和描述添加到属性中
},
"required": [param['name'] for param in tool['params'] if param['required']] # 获取必需参数的名称列表
}
}
}
openai_tools.append(openai_tool) # 将格式化后的工具添加到列表中
return openai_tools # 返回格式化后的 OpenAI 工具列表
class APIClient(Client): # 定义 APIClient 类,继承自 Client 类
def __init__(self, model_path: str): # 初始化方法,接受模型路径作为参数
base_url = "http://127.0.0.1:8000/v1/" # 定义基础 URL
self.client = OpenAI(api_key="EMPTY", base_url=base_url) # 创建 OpenAI 客户端实例,API 密钥设置为“EMPTY”
self.use_stream = False # 设置使用流的标志为 False
self.role_name_replace = {'observation': 'tool'} # 定义角色名称替换映射
def generate_stream( # 定义生成流的方法
self,
tools: list[dict], # 接受工具的列表
history: list[Conversation], # 接受对话历史的列表
**parameters, # 接受额外参数
) -> Generator[tuple[str | dict, list[dict]]]: # 返回生成器,类型为字符串或字典的元组和字典的列表
chat_history = process_input(history, '', role_name_replace=self.role_name_replace) # 处理输入历史,返回聊天历史
# messages = process_input(history, '', role_name_replace=self.role_name_replace) # 注释掉的代码,处理输入历史,返回消息(未使用)
openai_tools = format_openai_tool(tools) # 格式化工具列表
response = self.client.chat.completions.create( # 调用 OpenAI 客户端创建聊天补全请求
model="glm-4", # 指定使用的模型
messages=chat_history, # 提供聊天历史
tools=openai_tools, # 提供格式化后的工具
stream=self.use_stream, # 指定是否使用流
max_tokens=parameters["max_new_tokens"], # 设置最大生成的 tokens 数量
temperature=parameters["temperature"], # 设置生成的温度参数
presence_penalty=1.2, # 设置存在惩罚的值
top_p=parameters["top_p"], # 设置 top-p 采样的值
tool_choice="auto" # 设置工具选择为自动
)
output = response.choices[0].message # 获取响应中的第一个选择的消息
if output.tool_calls: # 检查输出是否包含工具调用
glm4_output = output.tool_calls[0].function.name + '\n' + output.tool_calls[0].function.arguments # 构建工具调用的输出
else: # 如果没有工具调用
glm4_output = output.content # 获取响应内容
yield process_response(glm4_output, chat_history) # 处理输出并生成结果
.\chatglm4-finetune\composite_demo\src\clients\vllm.py
"""
vLLM client. # vLLM 客户端的说明
Please install [vLLM](https://github.com/vllm-project/vllm) according to its
installation guide before running this client. # 提示用户在运行客户端前安装 vLLM
"""
import time # 导入时间模块,用于时间相关操作
from collections.abc import Generator # 从 collections 模块导入 Generator 类型
from transformers import AutoTokenizer # 从 transformers 模块导入自动分词器
from vllm import SamplingParams, LLMEngine, EngineArgs # 从 vllm 模块导入相关类
from client import Client, process_input, process_response # 从 client 模块导入 Client 类和处理函数
from conversation import Conversation # 从 conversation 模块导入 Conversation 类
class VLLMClient(Client): # 定义 VLLMClient 类,继承自 Client 类
def __init__(self, model_path: str): # 初始化方法,接收模型路径
self.tokenizer = AutoTokenizer.from_pretrained( # 创建分词器实例,从预训练模型加载
model_path, trust_remote_code=True # 指定模型路径,信任远程代码
)
self.engine_args = EngineArgs( # 创建引擎参数对象
model=model_path, # 设置模型路径
tensor_parallel_size=1, # 设置张量并行大小为 1
dtype="bfloat16", # 指定数据类型为 bfloat16,适用于高性能计算
trust_remote_code=True, # 信任远程代码
gpu_memory_utilization=0.6, # 设置 GPU 内存利用率为 60%
enforce_eager=True, # 强制使用即时执行
worker_use_ray=False, # 设置不使用 Ray 进行工作管理
)
self.engine = LLMEngine.from_engine_args(self.engine_args) # 从引擎参数创建 LLM 引擎实例
def generate_stream( # 定义生成流的方法
self, tools: list[dict], history: list[Conversation], **parameters # 接收工具列表、对话历史和其他参数
) -> Generator[tuple[str | dict, list[dict]]]: # 返回生成器,产生元组类型的输出
chat_history = process_input(history, tools) # 处理输入,将历史记录与工具结合
model_inputs = self.tokenizer.apply_chat_template( # 应用聊天模板生成模型输入
chat_history, add_generation_prompt=True, tokenize=False # 设置生成提示并禁用分词
)
parameters["max_tokens"] = parameters.pop("max_new_tokens") # 将 max_new_tokens 转换为 max_tokens
params_dict = { # 创建参数字典
"n": 1, # 设置生成样本数量为 1
"best_of": 1, # 设置最佳选择数量为 1
"top_p": 1, # 设置 nucleus 采样的阈值为 1
"top_k": -1, # 设置 top-k 采样为禁用状态
"use_beam_search": False, # 禁用束搜索
"length_penalty": 1, # 设置长度惩罚为 1
"early_stopping": False, # 禁用提前停止
"stop_token_ids": [151329, 151336, 151338], # 设置停止标记的 ID 列表
"ignore_eos": False, # 不忽略结束标记
"logprobs": None, # 不记录概率日志
"prompt_logprobs": None, # 不记录提示的概率日志
}
params_dict.update(parameters) # 更新参数字典,加入其他传入参数
sampling_params = SamplingParams(**params_dict) # 创建采样参数实例
self.engine.add_request( # 向引擎添加请求
request_id=str(time.time()), inputs=model_inputs, params=sampling_params # 设置请求 ID 和参数
)
while self.engine.has_unfinished_requests(): # 当引擎有未完成的请求时
request_outputs = self.engine.step() # 执行一步,获取请求输出
for request_output in request_outputs: # 遍历每个请求输出
yield process_response(request_output.outputs[0].text, chat_history) # 处理输出并生成响应
.\chatglm4-finetune\composite_demo\src\conversation.py
# 导入 JSON 模块,用于处理 JSON 数据
import json
# 导入正则表达式模块,用于字符串匹配
import re
# 从 dataclasses 模块导入 dataclass 装饰器,用于简化数据类的定义
from dataclasses import dataclass
# 从 datetime 模块导入 datetime 类,用于处理日期和时间
from datetime import datetime
# 从 enum 模块导入 Enum 和 auto,分别用于定义枚举和自动赋值
from enum import Enum, auto
# 导入 Streamlit 库,用于构建网页应用
import streamlit as st
# 从 Streamlit 的 delta_generator 模块导入 DeltaGenerator,用于动态内容生成
from streamlit.delta_generator import DeltaGenerator
# 从 PIL.Image 导入 Image 类,用于处理图像
from PIL.Image import Image
# 从 tools.browser 导入 Quote 和 quotes,用于处理引用和引用列表
from tools.browser import Quote, quotes
# 定义一个正则表达式,用于匹配特定格式的引用
QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
# 定义自我介绍提示,说明该助手的身份和任务
SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
# 定义日期提示的格式
DATE_PROMPT = "当前日期: %Y-%m-%d"
# 定义工具系统提示,包含不同工具的说明
TOOL_SYSTEM_PROMPTS = {
"python": "当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。",
"simple_browser": "你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。",
"cogview": "如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。",
}
# 定义文件模板字符串,用于生成文件内容
FILE_TEMPLATE = "[File Name]\n{file_name}\n[File Content]\n{file_content}"
# 定义构建系统提示的函数,接收可用工具和函数列表
def build_system_prompt(
enabled_tools: list[str], # 可用工具列表
functions: list[dict], # 函数列表
):
# 初始化提示内容为自我介绍
value = SELFCOG_PROMPT
# 将当前日期添加到提示内容中
value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
# 如果有可用工具或函数,则添加提示标记
if enabled_tools or functions:
value += "\n\n# 可用工具"
# 初始化内容列表
contents = []
# 遍历每个可用工具,添加其描述
for tool in enabled_tools:
contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
# 遍历每个函数,添加其描述和调用说明
for function in functions:
content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
contents.append(content)
# 将所有内容合并到提示中
value += "".join(contents)
# 返回构建的系统提示
return value
# 定义将响应转换为字符串的函数,支持字符串或字典类型
def response_to_str(response: str | dict[str, str]) -> str:
"""
将响应转换为字符串。
"""
# 如果响应是字典类型,则提取名称和内容
if isinstance(response, dict):
return response.get("name", "") + response.get("content", "")
# 如果响应是字符串类型,直接返回
return response
# 定义角色的枚举类,包含不同角色的定义
class Role(Enum):
SYSTEM = auto() # 系统角色
USER = auto() # 用户角色
ASSISTANT = auto() # 助手角色
TOOL = auto() # 工具角色
OBSERVATION = auto() # 观察角色
# 定义角色转换为字符串的方法
def __str__(self):
match self:
case Role.SYSTEM: # 如果是系统角色,返回对应字符串
return "<|system|>"
case Role.USER: # 如果是用户角色,返回对应字符串
return "<|user|>"
case Role.ASSISTANT | Role.TOOL: # 如果是助手或工具角色,返回对应字符串
return "<|assistant|>"
case Role.OBSERVATION: # 如果是观察角色,返回对应字符串
return "<|observation|>"
# 获取给定角色的消息块
# 定义获取消息的方法
def get_message(self):
# 由于 streamlit 的重跑行为,比较值而不是比较对象
# 因为会话状态中的枚举对象与此处的枚举情况不同
match self.value:
# 如果值是系统角色,则不返回任何内容
case Role.SYSTEM.value:
return
# 如果值是用户角色,则返回用户聊天消息
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
# 如果值是助手角色,则返回助手聊天消息
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
# 如果值是工具角色,则返回工具聊天消息
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
# 如果值是观察角色,则返回观察聊天消息
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="assistant")
# 如果角色不匹配任何已知情况,则显示错误信息
case _:
st.error(f"Unexpected role: {self}")
# 定义一个数据类,用于表示对话内容
@dataclass
class Conversation:
# 对话的角色(如用户、助手等)
role: Role
# 对话的内容,可以是字符串或字典
content: str | dict
# 处理过的内容,默认为 None
saved_content: str | None = None
# 附加的元数据,默认为 None
metadata: str | None = None
# 附带的图像,默认为 None
image: str | Image | None = None
# 返回对话对象的字符串表示
def __str__(self) -> str:
# 如果有元数据则使用它,否则为空字符串
metadata_str = self.metadata if self.metadata else ""
# 格式化并返回角色和内容
return f"{self.role}{metadata_str}\n{self.content}"
# 返回人类可读的格式
def get_text(self) -> str:
# 使用保存的内容或原始内容
text = self.saved_content or self.content
# 根据角色类型决定文本格式
match self.role.value:
case Role.TOOL.value:
# 格式化工具调用的信息
text = f"Calling tool `{self.metadata}`:\n\n```py\n{text}\n```py"
case Role.OBSERVATION.value:
# 格式化观察结果的信息
text = f"```py\n{text}\n```py"
# 返回处理后的文本
return text
# 以 markdown 块的形式展示内容
def show(self, placeholder: DeltaGenerator | None = None) -> str:
# 使用占位符消息或角色消息
if placeholder:
message = placeholder
else:
message = self.role.get_message()
# 如果有图像,则添加图像
if self.image:
message.image(self.image, width=512)
# 如果角色为观察,则格式化消息
if self.role == Role.OBSERVATION:
metadata_str = f"from {self.metadata}" if self.metadata else ""
message = message.expander(f"Observation {metadata_str}")
# 获取文本内容
text = self.get_text()
# 根据角色决定展示的文本内容
if self.role != Role.USER:
show_text = text
else:
# 分割文本以处理上传的文件内容
splitted = text.split('files uploaded.\n')
if len(splitted) == 1:
show_text = text
else:
# 显示文档内容的扩展器
doc = splitted[0]
show_text = splitted[-1]
expander = message.expander(f'File Content')
expander.markdown(doc)
# 使用 markdown 格式展示最终文本
message.markdown(show_text)
# 后处理文本内容的函数
def postprocess_text(text: str, replace_quote: bool) -> str:
# 替换小括号为美元符号
text = text.replace("\(", "$")
# 替换小括号为美元符号
text = text.replace("\)", "$")
# 替换中括号为双美元符号
text = text.replace("\[", "$$")
# 替换中括号为双美元符号
text = text.replace("\]", "$$")
# 移除特定标签
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
text = text.replace("<|user|>", "")
text = text.replace("<|endoftext|>", "")
# 如果需要替换引用
if replace_quote:
# 遍历找到的引用
for match in QUOTE_REGEX.finditer(text):
quote_id = match.group(1)
# 获取引用内容,如果未找到则使用默认信息
quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
# 替换引用文本
text = text.replace(
match.group(0), f" (来源:[{quote.title}]({quote.url})) "
)
# 返回处理后的文本,去除前后空白
return text.strip()
.\chatglm4-finetune\composite_demo\src\main.py
# 这个文档演示 GLM-4 的所有工具和长上下文聊天能力
"""
This demo show the All tools and Long Context chat Capabilities of GLM-4.
Please follow the Readme.md to run the demo.
"""
# 导入操作系统模块
import os
# 导入 traceback 模块,用于调试时打印异常信息
import traceback
# 导入枚举类
from enum import Enum
# 导入字节流操作类
from io import BytesIO
# 导入生成唯一标识符的函数
from uuid import uuid4
# 导入 Streamlit 库,用于创建网页应用
import streamlit as st
# 从 Streamlit 的 delta_generator 模块导入 DeltaGenerator 类
from streamlit.delta_generator import DeltaGenerator
# 导入处理图像的库
from PIL import Image
# 导入客户端相关的类和函数
from client import Client, ClientType, get_client
# 从 conversation 模块导入相关的常量和类
from conversation import (
FILE_TEMPLATE,
Conversation,
Role,
postprocess_text,
response_to_str,
)
# 从工具注册模块导入调度工具和获取工具的函数
from tools.tool_registry import dispatch_tool, get_tools
# 导入文本提取相关的实用函数
from utils import extract_pdf, extract_docx, extract_pptx, extract_text
# 获取聊天模型路径,如果未设置则使用默认值
CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
# 获取多模态模型路径,如果未设置则使用默认值
VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
# 判断是否使用 VLLM,根据环境变量进行设置
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
# 判断是否使用 API,根据环境变量进行设置
USE_API = os.environ.get("USE_API", "0") == "1"
# 定义模式枚举类
class Mode(str, Enum):
# 所有工具模式的标识
ALL_TOOLS = "
标签:4v,text,9B,content,state,源码,import,model,history
From: https://www.cnblogs.com/apachecn/p/18491988