首页 > 编程语言 >GLM-4v-9B-源码解析-三-

GLM-4v-9B-源码解析-三-

时间:2024-10-22 10:25:11浏览次数:5  
标签:4v self 9B ids 源码 output input model config

GLM-4v-9B 源码解析(三)

.\chatglm4-finetune\composite_demo\src\tools\browser.py

# 简单的浏览器工具说明
"""
Simple browser tool.

# Usage

Please start the backend browser server according to the instructions in the README.
"""

# 导入用于格式化输出的模块
from pprint import pprint
# 导入正则表达式模块
import re
# 导入请求模块,用于发送 HTTP 请求
import requests
# 导入 Streamlit 库,用于创建 Web 应用
import streamlit as st
# 导入数据类模块,用于定义数据结构
from dataclasses import dataclass

# 导入浏览器服务器的 URL 配置
from .config import BROWSER_SERVER_URL
# 导入工具观察接口
from .interface import ToolObservation

# 定义正则表达式用于匹配引用格式
QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")

# 定义引用数据类,用于存储标题和 URL
@dataclass
class Quote:
    title: str
    url: str

# 检查会话状态中是否包含引用信息,如果没有则初始化为空字典
if "quotes" not in st.session_state:
    st.session_state.quotes = {}

# 获取会话状态中的引用字典
quotes: dict[str, Quote] = st.session_state.quotes

# 定义映射响应的函数,将响应转换为工具观察对象
def map_response(response: dict) -> ToolObservation:
    # 打印浏览器响应以供调试
    print('===BROWSER_RESPONSE===')
    pprint(response)
    # 获取角色元数据
    role_metadata = response.get("roleMetadata")
    # 获取其他元数据
    metadata = response.get("metadata")
    
    # 处理引用结果
    if role_metadata.split()[0] == 'quote_result' and metadata:
        # 提取引用 ID
        quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
        # 获取引用的元数据
        quote: dict[str, str] = metadata['metadata_list'][0]
        # 将引用添加到引用字典
        quotes[quote_id] = Quote(quote['title'], quote['url'])
    # 处理浏览器结果
    elif role_metadata == 'browser_result' and metadata:
        # 遍历元数据列表,将每个引用添加到字典
        for i, quote in enumerate(metadata['metadata_list']):
            quotes[str(i)] = Quote(quote['title'], quote['url'])

    # 返回工具观察对象,包含内容类型、文本、角色元数据和元数据
    return ToolObservation(
        content_type=response.get("contentType"),
        text=response.get("result"),
        role_metadata=role_metadata,
        metadata=metadata,
    )

# 定义工具调用函数,接受代码和会话 ID 作为参数
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
    # 构建请求字典,包含会话 ID 和操作代码
    request = {
        "session_id": session_id,
        "action": code,
    }
    # 发送 POST 请求到浏览器服务器并获取响应
    response = requests.post(BROWSER_SERVER_URL, json=request).json()
    # 将响应映射为工具观察对象的列表并返回
    return list(map(map_response, response))

.\chatglm4-finetune\composite_demo\src\tools\cogview.py

# 导入 Streamlit 库,用于构建 Web 应用
import streamlit as st
# 导入 ZhipuAI 类,作为 AI 客户端
from zhipuai import ZhipuAI
# 从 ZhipuAI 库导入生成的图像类型
from zhipuai.types.image import GeneratedImage

# 从本地配置文件导入模型名称和 API 密钥
from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
# 从本地接口模块导入工具观察类
from .interface import ToolObservation

# 使用 Streamlit 的缓存机制缓存 ZhipuAI 客户端
@st.cache_resource
def get_zhipu_client():
    # 创建并返回一个 ZhipuAI 客户端实例,使用 API 密钥
    return ZhipuAI(api_key=ZHIPU_AI_KEY)

# 定义映射响应的函数,接收生成的图像作为参数
def map_response(img: GeneratedImage):
    # 返回工具观察对象,包含图像的相关信息
    return ToolObservation(
        content_type='image',  # 设置内容类型为图像
        text='CogView 已经生成并向用户展示了生成的图片。',  # 返回的文本信息
        image_url=img.url,  # 图像的 URL 地址
        role_metadata='cogview_result'  # 角色元数据,标识来源
    )

# 定义工具调用函数,接收提示文本和会话 ID
def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
    # 获取 ZhipuAI 客户端实例
    client = get_zhipu_client()
    # 调用生成图像的 API,返回响应数据
    response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
    # 将响应映射为工具观察列表并返回
    return list(map(map_response, response))

.\chatglm4-finetune\composite_demo\src\tools\config.py

# 定义浏览器服务器的 URL 地址
BROWSER_SERVER_URL = 'http://localhost:3000'

# 定义使用的 IP 内核名称
IPYKERNEL = 'glm-4-demo'

# 定义 Zhipu AI 的密钥,默认为空字符串
ZHIPU_AI_KEY = ''

# 定义使用的 COGVIEW 模型名称
COGVIEW_MODEL = 'cogview-3'

.\chatglm4-finetune\composite_demo\src\tools\interface.py

# 从数据类模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从类型提示模块导入 Any 类型
from typing import Any

# 定义一个数据类 ToolObservation,自动生成初始化方法等
@dataclass
class ToolObservation:
    # 定义 content_type 属性,表示内容的类型
    content_type: str
    # 定义 text 属性,表示文本内容
    text: str
    # 定义 image_url 属性,表示图像的 URL,可以为 None
    image_url: str | None = None
    # 定义 role_metadata 属性,表示角色的元数据,可以为 None
    role_metadata: str | None = None
    # 定义 metadata 属性,表示其他元数据,可以是任何类型
    metadata: Any = None

.\chatglm4-finetune\composite_demo\src\tools\python.py

# 导入用于打印美化的模块
from pprint import pprint
# 导入队列模块
import queue
# 导入正则表达式模块
import re
# 导入子进程模块的管道功能
from subprocess import PIPE
# 导入字面量类型
from typing import Literal

# 导入 Jupyter 客户端模块
import jupyter_client
# 导入 Streamlit 模块
import streamlit as st

# 定义正则表达式用于匹配 ANSI 转义序列
ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
# 定义正则表达式用于匹配代码块
CODE = re.compile(r'```([^\n]*)\n(.*?)```py')

# 定义 CodeKernel 类
class CodeKernel:
    # 初始化类的构造函数
    def __init__(self,
                 kernel_name='kernel',  # 设置内核名称,默认为 'kernel'
                 kernel_id=None,  # 可选内核 ID
                 kernel_config_path="",  # 内核配置文件路径
                 python_path=None,  # Python 路径
                 ipython_path=None,  # IPython 路径
                 init_file_path="./startup.py",  # 初始化文件路径
                 verbose=1):  # 是否打印详细信息

        # 初始化内核名称
        self.kernel_name = kernel_name
        # 初始化内核 ID
        self.kernel_id = kernel_id
        # 初始化内核配置文件路径
        self.kernel_config_path = kernel_config_path
        # 初始化 Python 路径
        self.python_path = python_path
        # 初始化 IPython 路径
        self.ipython_path = ipython_path
        # 初始化启动文件路径
        self.init_file_path = init_file_path
        # 初始化详细模式
        self.verbose = verbose

        # 如果没有提供 Python 和 IPython 路径,设置环境变量为 None
        if python_path is None and ipython_path is None:
            env = None
        else:
            # 设置环境变量包含 Python 路径
            env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}

        # 初始化后端内核管理器
        self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
                                                           connection_file=self.kernel_config_path,
                                                           exec_files=[self.init_file_path],
                                                           env=env)
        # 如果有配置文件路径,加载连接文件并启动内核
        if self.kernel_config_path:
            self.kernel_manager.load_connection_file()
            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
            # 打印后端内核启动的信息
            print("Backend kernel started with the configuration: {}".format(
                self.kernel_config_path))
        else:
            # 否则直接启动内核
            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
            # 打印后端内核启动的信息
            print("Backend kernel started with the configuration: {}".format(
                self.kernel_manager.connection_file))

        # 如果 verbose 为真,打印连接信息
        if verbose:
            pprint(self.kernel_manager.get_connection_info())

        # 初始化代码内核
        self.kernel = self.kernel_manager.blocking_client()
        # 启动内核通道
        self.kernel.start_channels()
        # 打印代码内核启动的信息
        print("Code kernel started.")
    # 定义执行代码的方法
        def execute(self, code):
            # 执行给定的代码
            self.kernel.execute(code)
            try:
                # 获取 shell 消息,最多等待 30 秒
                shell_msg = self.kernel.get_shell_msg(timeout=30)
                # 获取 IOPub 消息内容,最多等待 30 秒
                io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
                # 无限循环,直到执行状态变为 idle
                while True:
                    # 保存当前 IO 消息内容
                    msg_out = io_msg_content
                    ### 轮询消息
                    try:
                        # 获取新的 IOPub 消息内容,最多等待 30 秒
                        io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
                        # 如果执行状态为 idle,则退出循环
                        if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
                            break
                    except queue.Empty:
                        # 如果没有新消息,退出循环
                        break
    
                # 返回 shell 消息和最后的输出消息
                return shell_msg, msg_out
            except Exception as e:
                # 打印异常信息
                print(e)
                # 如果发生异常,返回 None
                return None
    
    # 定义交互式执行代码的方法
        def execute_interactive(self, code, verbose=False):
            # 交互式执行给定代码,获取 shell 消息
            shell_msg = self.kernel.execute_interactive(code)
            # 如果没有 shell 消息,则处理超时
            if shell_msg is queue.Empty:
                if verbose:
                    # 打印超时信息
                    print("Timeout waiting for shell message.")
            # 检查消息状态
            self.check_msg(shell_msg, verbose=verbose)
    
            # 返回 shell 消息
            return shell_msg
    
    # 定义检查代码的方法
        def inspect(self, code, verbose=False):
            # 发送代码检查请求,获取消息 ID
            msg_id = self.kernel.inspect(code)
            # 获取 shell 消息,最多等待 30 秒
            shell_msg = self.kernel.get_shell_msg(timeout=30)
            # 如果没有 shell 消息,则处理超时
            if shell_msg is queue.Empty:
                if verbose:
                    # 打印超时信息
                    print("Timeout waiting for shell message.")
            # 检查消息状态
            self.check_msg(shell_msg, verbose=verbose)
    
            # 返回 shell 消息
            return shell_msg
    
    # 定义获取错误消息的方法
        def get_error_msg(self, msg, verbose=False) -> str | None:
            # 检查消息状态是否为错误
            if msg['content']['status'] == 'error':
                try:
                    # 尝试获取完整的 traceback
                    error_msg = msg['content']['traceback']
                except:
                    try:
                        # 尝试获取最后一行的 traceback
                        error_msg = msg['content']['traceback'][-1].strip()
                    except:
                        # 如果都失败,返回默认错误信息
                        error_msg = "Traceback Error"
                if verbose:
                    # 打印错误信息
                    print("Error: ", error_msg)
                # 返回错误消息
                return error_msg
            # 如果没有错误,返回 None
            return None
    
    # 定义检查消息状态的方法
        def check_msg(self, msg, verbose=False):
            # 获取消息状态
            status = msg['content']['status']
            # 如果状态为 ok,表示执行成功
            if status == 'ok':
                if verbose:
                    # 打印执行成功信息
                    print("Execution succeeded.")
            # 如果状态为 error,打印 traceback
            elif status == 'error':
                for line in msg['content']['traceback']:
                    if verbose:
                        # 打印每行 traceback
                        print(line)
    
    # 定义关闭内核的方法
        def shutdown(self):
            # 关闭后端内核
            self.kernel_manager.shutdown_kernel()
            print("Backend kernel shutdown.")
            # 关闭代码内核
            self.kernel.shutdown()
            print("Code kernel shutdown.")
    
    # 定义重启内核的方法
        def restart(self):
            # 重启后端内核
            self.kernel_manager.restart_kernel()
            # print("Backend kernel restarted.")
    
    # 定义中断内核的方法
        def interrupt(self):
            # 中断后端内核
            self.kernel_manager.interrupt_kernel()
            # print("Backend kernel interrupted.")
    
    # 定义检查内核是否存活的方法
        def is_alive(self):
            # 返回内核存活状态
            return self.kernel.is_alive()
# 定义一个函数,用于清理输入字符串中的 ANSI 代码
def clean_ansi_codes(input_string):
    # 使用正则表达式去除输入字符串中的 ANSI 转义序列
    return ANSI_ESCAPE.sub('', input_string)

# 定义一个函数,从文本中提取代码段
def extract_code(text: str) -> str:
    # 查找文本中所有的代码段,返回匹配的列表
    matches = CODE.findall(text, re.DOTALL)
    # 返回最后一个匹配的代码段(假设代码段是元组,取第二个元素)
    return matches[-1][1]

# 定义一个执行代码的函数
def execute(
    code: str,
    kernel: CodeKernel
) -> tuple[Literal['text', 'image'] | None, str]:
    # 初始化结果和结果类型
    res = ""
    res_type = None
    # 清理代码中的特定 XML 标签
    code = code.replace("<|observation|>", "")
    code = code.replace("<|assistant|>python", "")
    code = code.replace("<|assistant|>", "")
    code = code.replace("<|user|>", "")
    code = code.replace("<|system|>", "")
    # 执行代码并获取消息和输出
    msg, output = kernel.execute(code)

    # 检查执行状态是否超时
    if msg['metadata']['status'] == "timeout":
        return res_type, 'Timed out'
    # 检查执行状态是否出错
    elif msg['metadata']['status'] == 'error':
        # 返回错误信息,清理 ANSI 代码
        return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))

    # 检查输出中是否包含文本
    if 'text' in output:
        res_type = "text"  # 设置结果类型为文本
        res = output['text']  # 获取文本结果
    # 检查输出中是否包含数据
    elif 'data' in output:
        # 遍历输出数据的每一个键
        for key in output['data']:
            # 如果数据类型是文本,设置结果类型和结果
            if 'text/plain' in key:
                res_type = "text"
                res = output['data'][key]
            # 如果数据类型是图片,设置结果类型和结果
            elif 'image/png' in key:
                res_type = "image"
                res = output['data'][key]
                break  # 找到图片后退出循环

    # 返回结果类型和结果
    return res_type, res

# 使用 Streamlit 的缓存机制定义一个获取内核的函数
@st.cache_resource
def get_kernel() -> CodeKernel:
    # 创建并返回一个新的 CodeKernel 实例
    return CodeKernel()

# 定义一个工具调用的函数
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
    # 获取内核
    kernel = get_kernel()
    # 执行代码并获取结果类型和结果
    res_type, res = execute(code, kernel)

    # 根据结果类型转换为数据 URI
    text = '[Image]' if res_type == 'image' else res  # 如果是图片,设置文本为 '[Image]'
    image = f'data:image/png;base64,{res}' if res_type == 'image' else None  # 如果是图片,生成数据 URI

    # 返回包含工具观察结果的列表
    return [ToolObservation(res_type, text, image)]

.\chatglm4-finetune\composite_demo\src\tools\tool_registry.py

"""
该代码是工具注册部分。通过注册工具,模型可以调用该工具。
该代码为模型提供扩展功能,使其能够通过定义的接口调用和与各种工具交互。
"""

# 导入所需的模块和类
from collections.abc import Callable  # 导入 Callable 类型用于函数类型注解
import copy  # 导入 copy 模块用于对象复制
import inspect  # 导入 inspect 模块用于获取对象的信息
import json  # 导入 json 模块用于处理 JSON 数据
from pprint import pformat  # 导入 pformat 函数用于格式化输出
import traceback  # 导入 traceback 模块用于异常跟踪
from types import GenericAlias  # 导入 GenericAlias 类型用于泛型处理
from typing import get_origin, Annotated  # 导入类型相关工具
import subprocess  # 导入 subprocess 模块用于子进程管理

from .interface import ToolObservation  # 从当前包导入 ToolObservation 类

# 从不同模块导入工具调用
from .browser import tool_call as browser  # 导入浏览器工具调用
from .cogview import tool_call as cogview  # 导入 CogView 工具调用
from .python import tool_call as python  # 导入 Python 工具调用

# 定义所有可用工具的字典
ALL_TOOLS = {
    "simple_browser": browser,  # 将浏览器工具关联到其名称
    "python": python,  # 将 Python 工具关联到其名称
    "cogview": cogview,  # 将 CogView 工具关联到其名称
}

_TOOL_HOOKS = {}  # 初始化工具钩子字典,用于存储注册的工具
_TOOL_DESCRIPTIONS = []  # 初始化工具描述列表,用于存储工具信息


def register_tool(func: Callable):
    # 获取工具的名称
    tool_name = func.__name__
    # 获取工具的描述文档并去除首尾空格
    tool_description = inspect.getdoc(func).strip()
    # 获取工具参数的签名
    python_params = inspect.signature(func).parameters
    tool_params = []  # 初始化工具参数列表
    for name, param in python_params.items():  # 遍历每个参数
        annotation = param.annotation  # 获取参数的注解
        # 检查参数是否缺少类型注解
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter `{name}` missing type annotation")
        # 检查注解类型是否为 Annotated
        if get_origin(annotation) != Annotated:
            raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")

        # 获取类型和描述、是否必需
        typ, (description, required) = annotation.__origin__, annotation.__metadata__
        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__  # 确保类型为字符串
        # 检查描述是否为字符串
        if not isinstance(description, str):
            raise TypeError(f"Description for `{name}` must be a string")
        # 检查是否必需是否为布尔值
        if not isinstance(required, bool):
            raise TypeError(f"Required for `{name}` must be a bool")

        # 添加参数信息到工具参数列表
        tool_params.append(
            {
                "name": name,
                "description": description,
                "type": typ,
                "required": required,
            }
        )
    # 创建工具定义字典
    tool_def = {
        "name": tool_name,
        "description": tool_description,
        "params": tool_params,
    }
    # print("[registered tool] " + pformat(tool_def))  # 可选的调试输出
    _TOOL_HOOKS[tool_name] = func  # 将工具名称与函数绑定
    _TOOL_DESCRIPTIONS.append(tool_def)  # 将工具定义添加到描述列表

    return func  # 返回注册的工具函数


def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]:
    # 分发预定义的工具
    if tool_name in ALL_TOOLS:
        return ALL_TOOLS[tool_name](code, session_id)  # 调用相应工具

    # 清理代码字符串
    code = code.strip().rstrip('<|observation|>').strip()

    # 分发自定义工具
    try:
        tool_params = json.loads(code)  # 尝试解析 JSON 格式的代码
    except json.JSONDecodeError as e:  # 捕获 JSON 解码错误
        err = f"Error decoding JSON: {e}"  # 创建错误信息
        return [ToolObservation("system_error", err)]  # 返回错误观察对象

    # 检查工具名称是否在已注册的工具中
    if tool_name not in _TOOL_HOOKS:
        err = f"Tool `{tool_name}` not found. Please use a provided tool."  # 错误信息
        return [ToolObservation("system_error", err)]  # 返回错误观察对象

    tool_hook = _TOOL_HOOKS[tool_name]  # 获取对应的工具钩子
    try:
        ret: str = tool_hook(**tool_params)  # 调用工具并传递参数
        return [ToolObservation(tool_name, str(ret))]  # 返回工具执行结果
    # 捕获异常,执行以下语句
    except:
        # 格式化当前异常的堆栈信息,保存到 err 变量
        err = traceback.format_exc()
        # 返回一个包含错误信息的 ToolObservation 对象的列表
        return [ToolObservation("system_error", err)]
# 获取工具的定义,返回工具描述的深拷贝列表
def get_tools() -> list[dict]:
    # 返回工具描述的深拷贝,避免修改原始数据
    return copy.deepcopy(_TOOL_DESCRIPTIONS)


# 工具定义部分


# 注册一个工具,生成随机数
@register_tool
def random_number_generator(
        # 随机生成器使用的种子,必须为整数
        seed: Annotated[int, "The random seed used by the generator", True],
        # 生成数值的范围,必须为整数元组
        range: Annotated[tuple[int, int], "The range of the generated numbers", True],
) -> int:
    """
    生成一个随机数 x,使得 range[0] <= x < range[1]
    """
    # 检查种子是否为整数
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer")
    # 检查范围是否为元组
    if not isinstance(range, tuple):
        raise TypeError("Range must be a tuple")
    # 检查范围的每个元素是否为整数
    if not isinstance(range[0], int) or not isinstance(range[1], int):
        raise TypeError("Range must be a tuple of integers")

    # 导入随机数模块
    import random

    # 根据种子创建随机数生成器,生成指定范围内的随机整数
    return random.Random(seed).randint(*range)


# 注册一个工具,获取天气信息
@register_tool
def get_weather(
        # 要查询的城市名称,必须为字符串
        city_name: Annotated[str, "The name of the city to be queried", True],
) -> str:
    """
    获取指定城市的当前天气
    """

    # 检查城市名称是否为字符串
    if not isinstance(city_name, str):
        raise TypeError("City name must be a string")

    # 定义需要获取的天气信息的键
    key_selection = {
        "current_condition": [
            "temp_C",
            "FeelsLikeC",
            "humidity",
            "weatherDesc",
            "observation_time",
        ],
    }
    # 导入请求模块
    import requests

    try:
        # 发起请求以获取天气数据
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
        # 检查请求是否成功
        resp.raise_for_status()
        # 解析返回的 JSON 数据
        resp = resp.json()
        # 构建返回的数据字典
        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
    except:
        # 导入追踪模块以便于错误处理
        import traceback

        # 捕获异常,返回错误信息
        ret = (
                "Error encountered while fetching weather data!\n" + traceback.format_exc()
        )

    # 返回处理后的结果
    return str(ret)


# 注册一个工具,执行 Linux 命令
@register_tool
def get_shell(
        # 要在 Linux shell 中执行的命令,必须为字符串
        query: Annotated[str, "The command should run in Linux shell", True],
) -> str:
    """
    使用 shell 执行命令
    """
    # 检查命令是否为字符串
    if not isinstance(query, str):
        raise TypeError("Command must be a string")
    try:
        # 运行命令并捕获输出和错误信息
        result = subprocess.run(
            query,
            shell=True,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        # 返回命令的标准输出
        return result.stdout
    except subprocess.CalledProcessError as e:
        # 返回命令执行的错误信息
        return e.stderr


# 如果该文件是主程序
if __name__ == "__main__":
    # 测试执行 get_shell 工具
    # print(dispatch_tool("get_shell", {"query": "pwd"}))
    # 输出获取的工具列表
    print(get_tools())

.\chatglm4-finetune\composite_demo\src\utils.py

# 从 langchain_community 的文档加载器导入 PyMuPDFLoader
from langchain_community.document_loaders import PyMuPDFLoader
# 导入处理 Word 文档的库
import docx
# 导入处理 PowerPoint 演示文稿的库
from pptx import Presentation

# 定义提取文本的函数,接收文件路径作为参数
def extract_text(path):
    # 打开文件并读取其内容
    return open(path, 'r').read()

# 定义提取 PDF 内容的函数,接收文件路径作为参数
def extract_pdf(path):
    # 使用 PyMuPDFLoader 加载 PDF 文件
    loader = PyMuPDFLoader(path)
    # 从加载器中提取数据
    data = loader.load()
    # 提取每个页面的内容并存入列表
    data = [x.page_content for x in data]
    # 将所有页面内容合并为一个字符串
    content = '\n\n'.join(data)
    # 返回合并后的内容
    return content

# 定义提取 DOCX 内容的函数,接收文件路径作为参数
def extract_docx(path):
    # 使用 docx 库打开 DOCX 文件
    doc = docx.Document(path)
    # 初始化一个空列表以存储段落内容
    data = []
    # 遍历文档中的每个段落
    for paragraph in doc.paragraphs:
        # 将段落文本添加到列表中
        data.append(paragraph.text)
    # 将所有段落内容合并为一个字符串
    content = '\n\n'.join(data)
    # 返回合并后的内容
    return content

# 定义提取 PPTX 内容的函数,接收文件路径作为参数
def extract_pptx(path):
    # 使用 Presentation 类打开 PPTX 文件
    prs = Presentation(path)
    # 初始化一个空字符串以存储文本
    text = ""
    # 遍历每个幻灯片
    for slide in prs.slides:
        # 遍历幻灯片中的每个形状
        for shape in slide.shapes:
            # 检查形状是否包含文本属性
            if hasattr(shape, "text"):
                # 将形状的文本添加到字符串中,并换行
                text += shape.text + "\n"
    # 返回收集的文本
    return text

.\chatglm4-finetune\finetune_demo\finetune.py

# -*- coding: utf-8 -*-  # 指定文件编码为 UTF-8
import os  # 导入操作系统相关的模块
import jieba  # 导入中文分词库
import dataclasses as dc  # 导入数据类模块
import functools  # 导入用于高阶函数的工具
from collections.abc import Callable, Mapping, Sequence  # 导入集合相关的抽象基类
from pathlib import Path  # 导入处理路径的模块
from typing import Annotated, Any, Union  # 导入类型注解
import numpy as np  # 导入 NumPy 库
import ruamel.yaml as yaml  # 导入 YAML 处理库
import torch  # 导入 PyTorch 库
import typer  # 导入命令行界面库
from datasets import Dataset, Split  # 从 datasets 导入数据集相关类
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction  # 导入 BLEU 分数计算工具
from peft import PeftConfig, get_peft_config, get_peft_model  # 导入 PEFT 配置及模型获取函数
from rouge_chinese import Rouge  # 导入中文 ROUGE 评测工具
from torch import nn  # 导入 PyTorch 的神经网络模块
from transformers import (  # 导入变换器相关类和函数
    AutoModelForCausalLM,  # 自动加载因果语言模型
    AutoTokenizer,  # 自动加载分词器
    EvalPrediction,  # 导入评估预测结果的类
    GenerationConfig,  # 导入生成配置
    PreTrainedTokenizer,  # 导入预训练分词器
    Seq2SeqTrainingArguments,  # 导入序列到序列训练参数类
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq  # 导入序列到序列的数据整理器并重命名
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer  # 导入序列到序列训练器并重命名
from datasets import load_dataset, DatasetDict, NamedSplit  # 导入数据集加载和字典类
from typing import Optional  # 导入可选类型注解

app = typer.Typer(pretty_exceptions_show_locals=False)  # 创建命令行应用,禁用本地异常显示


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):  # 定义数据整理器类
    def __call__(self, features, return_tensors=None):  # 重载调用方法,接受特征和返回张量选项
        output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)  # 获取输出 ID
        if output_ids is not None:  # 如果存在输出 ID
            max_output_length = max(len(out) for out in output_ids)  # 计算最大输出长度
            if self.pad_to_multiple_of is not None:  # 如果需要填充到特定倍数
                max_output_length = (  # 计算新的最大输出长度
                        (
                                max_output_length + self.pad_to_multiple_of - 1) //
                        self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:  # 遍历特征
                remainder = [self.tokenizer.pad_token_id] * (  # 计算填充所需的剩余部分
                        max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):  # 如果输出 ID 是列表
                    feature['output_ids'] = feature['output_ids'] + remainder  # 追加填充
                else:  # 否则
                    feature['output_ids'] = np.concatenate(  # 将输出 ID 和填充合并
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)  # 转换为整型数组
        return super().__call__(features, return_tensors)  # 调用父类方法返回结果


class Seq2SeqTrainer(_Seq2SeqTrainer):  # 定义序列到序列训练器类
    # Not Support for apex  # 不支持 apex

    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:  # 定义训练步骤方法
        model.train()  # 设置模型为训练模式
        inputs = self._prepare_inputs(inputs)  # 准备输入数据

        with self.compute_loss_context_manager():  # 使用计算损失的上下文管理器
            loss = self.compute_loss(model, inputs)  # 计算损失

        if self.args.n_gpu > 1:  # 如果使用多个 GPU
            loss = loss.mean()  # 对损失进行平均
        self.accelerator.backward(loss)  # 反向传播损失
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps  # 分离损失并归一化
        del inputs  # 删除输入以释放内存
        torch.cuda.empty_cache()  # 清空 CUDA 缓存
        return detached_loss  # 返回处理后的损失

    def prediction_step(  # 定义预测步骤方法
            self,
            model: nn.Module,  # 模型
            inputs: dict[str, Any],  # 输入数据
            prediction_loss_only: bool,  # 是否仅返回预测损失
            ignore_keys=None,  # 要忽略的键
            **gen_kwargs,  # 其他生成参数
    # 定义函数返回值类型为包含可选浮点数和两个可选的 Torch 张量的元组
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    
        # 禁用梯度计算,以节省内存和提高速度
        with torch.no_grad():  # Ensure no gradient computation
            # 如果设置为生成预测,则从输入中移除输出 ID
            if self.args.predict_with_generate:
                output_ids = inputs.pop('output_ids')
            # 从输入中获取输入 ID
            input_ids = inputs['input_ids']
    
            # 调用父类的方法执行预测步骤,获取损失、生成的标记和标签
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )
    
            # 截取生成的标记,去掉输入 ID 部分
            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            # 将标签设置为输出 ID
            labels = output_ids
    
            # 删除输入、输入 ID 和输出 ID,以释放内存
            del inputs, input_ids, output_ids
            # 清空 CUDA 缓存,避免内存溢出
            torch.cuda.empty_cache()
    
        # 返回损失、生成的标记和标签
        return loss, generated_tokens, labels
# 使用数据类装饰器定义一个数据配置类
@dc.dataclass
class DataConfig(object):
    # 训练数据文件路径,可选,默认为 None
    train_file: Optional[str] = None
    # 验证数据文件路径,可选,默认为 None
    val_file: Optional[str] = None
    # 测试数据文件路径,可选,默认为 None
    test_file: Optional[str] = None
    # 处理数据的进程数量,可选,默认为 None
    num_proc: Optional[int] = None

    # 定义一个属性,用于获取训练文件的格式后缀
    @property
    def data_format(self) -> str:
        # 返回训练文件路径的后缀
        return Path(self.train_file).suffix

    # 定义一个属性,用于获取文件路径的字典
    @property
    def data_files(self) -> dict[NamedSplit, str]:
        # 返回一个字典,包含各个数据分割及其对应的文件路径
        return {
            split: data_file
            # zip 函数将分割类型与文件路径配对
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            # 仅包含非 None 的文件路径
            if data_file is not None
        }


# 使用数据类装饰器定义一个微调配置类
@dc.dataclass
class FinetuningConfig(object):
    # 关联的数据配置
    data_config: DataConfig

    # 最大输入长度
    max_input_length: int
    # 最大输出长度
    max_output_length: int
    # 是否合并数据
    combine: bool
    # 是否冻结 V
    freezeV: bool

    # 定义训练参数,使用默认工厂函数生成对象
    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    # 可选的 PEFT 配置
    peft_config: Optional[PeftConfig] = None

    # 后初始化方法,调整训练参数
    def __post_init__(self):
        # 如果不进行评估或验证文件为 None
        if not self.training_args.do_eval or self.data_config.val_file is None:
            # 设置不进行评估
            self.training_args.do_eval = False
            # 评估策略设置为 'no'
            self.training_args.evaluation_strategy = 'no'
            # 清空验证文件路径
            self.data_config.val_file = None
        else:
            # 设置评估批次大小
            self.training_args.per_device_eval_batch_size = (
                    self.training_args.per_device_eval_batch_size
                    or self.training_args.per_device_train_batch_size
            )

    # 从字典创建类的类方法
    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        # 获取训练参数
        training_args = kwargs.get('training_args', None)
        # 如果训练参数存在且不是 Seq2SeqTrainingArguments 类型
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            # 获取生成配置
            gen_config = training_args.get('generation_config')
            # 如果生成配置不是 GenerationConfig 类型
            if not isinstance(gen_config, GenerationConfig):
                # 创建生成配置并赋值
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            # 更新训练参数为 Seq2SeqTrainingArguments 类型
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        # 获取数据配置
        data_config = kwargs.get('data_config')
        # 如果数据配置不是 DataConfig 类型
        if not isinstance(data_config, DataConfig):
            # 更新为 DataConfig 类型
            kwargs['data_config'] = DataConfig(**data_config)

        # 获取 PEFT 配置
        peft_config = kwargs.get('peft_config', None)
        # 如果 PEFT 配置存在且不是 PeftConfig 类型
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            # 获取 PEFT 配置并赋值
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        # 返回新的类实例
        return cls(**kwargs)

    # 从文件创建类的类方法
    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        # 将路径转换为 Path 对象
        path = Path(path)
        # 创建 YAML 解析器
        parser = yaml.YAML(typ='safe', pure=True)
        # 设置缩进格式
        parser.indent(mapping=2, offset=2, sequence=4)
        # 设置默认的流样式为 False
        parser.default_flow_style = False
        # 从文件加载内容
        kwargs = parser.load(path)
        # 从字典创建类实例并返回
        return cls.from_dict(**kwargs)


# 定义一个加载数据集的函数
def _load_datasets(
        # 数据目录
        data_dir: str,
        # 数据格式
        data_format: str,
        # 数据文件字典
        data_files: dict[NamedSplit, str],
        # 进程数量
        num_proc: Optional[int],
) -> DatasetDict:
    # 检查数据格式是否为 '.jsonl'
    if data_format == '.jsonl':
        # 加载数据集,指定数据目录、数据文件、拆分方式和并行处理进程数
        dataset_dct = load_dataset(
            data_dir,  # 数据存储目录
            data_files=data_files,  # 要加载的数据文件列表
            split=None,  # 不指定拆分,加载全部数据
            num_proc=num_proc,  # 指定并行处理的进程数
        )
    else:
        # 如果数据格式不被支持,抛出未实现错误并提示格式
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    # 返回加载的数据集字典
    return dataset_dct
# 数据管理类,用于管理数据集
class DataManager(object):
    # 初始化方法,接受数据目录和数据配置对象
    def __init__(self, data_dir: str, data_config: DataConfig):
        # 从数据配置中获取并存储处理进程的数量
        self._num_proc = data_config.num_proc

        # 加载数据集并存储为字典,键为数据集分割,值为数据集
        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    # 根据分割获取对应的数据集
    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        # 从数据集字典中获取指定分割的数据集,若不存在则返回 None
        return self._dataset_dct.get(split, None)

    # 获取指定分割的数据集,并处理数据
    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        # 获取原始数据集
        orig_dataset = self._get_dataset(split)
        # 若原始数据集不存在,则返回 None
        if orig_dataset is None:
            return

        # 根据标志决定是否移除原始列
        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        # 调用 map 方法处理数据集并返回结果
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


# 处理消息函数
def process_message(message):
    # 如果消息中包含工具且角色为系统,则处理工具参数
    if 'tools' in message and message['role'] == 'system':
        for tool in message['tools']:
            # 获取工具的参数属性
            parameters = tool['function']['parameters']['properties']
            # 过滤掉参数值为 None 的属性
            tool['function']['parameters']['properties'] = \
                {k: v for k, v in parameters.items() if
                 v is not None}
    # 如果消息中包含工具,但角色不是系统,则删除工具
    elif 'tools' in message:
        del message['tools']
    # 返回处理后的消息
    return message


# 处理批次消息的函数
def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
        combine: bool,
) -> dict[str, list]:
    # 从批次中提取消息
    batched_conv = batch['messages']
    # 初始化存储输入 ID 的列表
    batched_input_ids = []
    # 初始化存储标签的列表
    batched_labels = []
    # 遍历分批的对话
        for conv in batched_conv:
            # 初始化输入 ID 列表
            input_ids = [151331, 151333]
            # 初始化损失掩码列表
            loss_masks = [False, False]
            # 如果需要合并对话
            if combine:
                # 应用聊天模板将对话转换为新的输入 ID
                new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
                # 更新输入 ID 列表
                input_ids = new_input_ids
                # 创建新的损失掩码,所有元素初始为 False
                loss_masks = [False] * len(input_ids)
                # 找到最后一个助手的索引
                last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
                # 为最后助手之后的输入设置掩码为 True
                for j in range(last_assistant_index + 1, len(input_ids)):
                    loss_masks[j] = True
            else:
                # 如果不合并,则处理每条消息
                for message in conv:
                    # 处理消息,提取有效信息
                    message = process_message(message)
                    # 确定损失掩码的值,根据角色决定
                    loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
                    # 应用聊天模板并更新输入 ID 列表,跳过前两个元素
                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                    # 将新的输入 ID 添加到输入 ID 列表
                    input_ids += new_input_ids
                    # 根据新的输入 ID 更新损失掩码
                    loss_masks += [loss_mask_val] * len(new_input_ids)
    
            # 在输入 ID 列表末尾添加结束符
            input_ids.append(151336)  # EOS for chat
            # 在损失掩码列表前添加一个 False
            loss_masks = [False, *loss_masks]
            # 初始化标签列表
            labels = []
            # 根据输入 ID 和损失掩码生成标签
            for input_id, mask in zip(input_ids, loss_masks):
                if mask:
                    labels.append(input_id)  # 如果掩码为 True,添加输入 ID
                else:
                    labels.append(-100)  # 否则添加 -100 作为无效标签
            # 计算最大长度
            max_length = max_input_length + max_output_length + 1
            # 将处理后的输入 ID 和标签添加到批次列表中,限制长度
            batched_input_ids.append(input_ids[:max_length])
            batched_labels.append(labels[:max_length])
    
        # 删除不再使用的变量以释放内存
        del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
        # 清空 CUDA 缓存以释放显存
        torch.cuda.empty_cache()
    
        # 返回输入 ID 和标签的字典
        return {'input_ids': batched_input_ids, 'labels': batched_labels}
# 处理批量评估的函数,返回输入和输出 ID 的字典
def process_batch_eval(
        batch: Mapping[str, Sequence],  # 输入批次,包含消息的映射
        tokenizer: PreTrainedTokenizer,  # 预训练的分词器
        max_input_length: int,  # 输入的最大长度
        max_output_length: int,  # 输出的最大长度
        combine: bool,  # 是否组合消息
) -> dict[str, list]:  # 返回类型为包含输入和输出 ID 的字典
    # 从批次中提取对话消息
    batched_conv = batch['messages']
    # 存储处理后的输入 ID 的列表
    batched_input_ids = []
    # 存储处理后的输出 ID 的列表
    batched_output_ids = []

    # 遍历每个对话
    for conv in batched_conv:
        if combine:  # 如果选择组合模式
            # 应用聊天模板对对话进行编码
            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            # 将新的输入 ID 赋值给输入 ID
            input_ids = new_input_ids
            # 获取最后一个助手消息的索引
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            # 分割输出提示和输出 ID
            output_prompt, output_ids = (
                input_ids[:1],  # 取第一个输入 ID 作为输出提示
                input_ids[last_assistant_index:],  # 取从助手消息开始的输出 ID
            )
            output_ids.append(151336)  # 添加结束符
            # 将处理后的输入 ID 添加到列表中,限制长度
            batched_input_ids.append(
                input_ids[:max_input_length] + output_prompt[:1]
            )
            # 将处理后的输出 ID 添加到列表中,限制长度
            batched_output_ids.append(output_ids[:max_output_length])
        else:  # 如果选择不组合模式
            input_ids = [151331, 151333]  # 初始化输入 ID
            # 遍历对话中的每个消息
            for message in conv:
                if len(input_ids) >= max_input_length:  # 如果输入长度超过最大限制
                    break  # 跳出循环
                else:
                    # 处理当前消息
                    message = process_message(message)
                    # 应用聊天模板对消息进行编码
                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                    if message['role'] == 'assistant':  # 如果消息来自助手
                        output_prompt, output_ids = (
                            new_input_ids[:1],  # 取第一个新的输入 ID 作为输出提示
                            new_input_ids[1:],  # 取剩余的输入 ID 作为输出 ID
                        )
                        output_ids.append(151336)  # 添加结束符
                        # 将处理后的输入 ID 添加到列表中,限制长度
                        batched_input_ids.append(
                            input_ids[:max_input_length] + output_prompt[:1]
                        )
                        # 将处理后的输出 ID 添加到列表中,限制长度
                        batched_output_ids.append(output_ids[:max_output_length])
                    # 更新输入 ID
                    input_ids += new_input_ids

    # 删除不再需要的变量以释放内存
    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
    # 清空 GPU 缓存
    torch.cuda.empty_cache()

    # 返回包含输入和输出 ID 的字典
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}


# 加载分词器和模型的函数
def load_tokenizer_and_model(
        model_dir: str,  # 模型目录
        peft_config: Optional[PeftConfig] = None,  # 可选的配置
):
    # 从指定目录加载分词器
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:  # 如果提供了 PEFT 配置
        # 从指定目录加载因果语言模型
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16  # 必须使用 BFloat 16
        )
        # 应用 PEFT 模型配置
        model = get_peft_model(model, peft_config)
        # 打印可训练参数
        model.print_trainable_parameters()
    else:  # 如果没有 PEFT 配置
        # 从指定目录加载因果语言模型
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16
        )
    # 返回分词器和模型
    return tokenizer, model


# 计算指标的函数
def compute_metrics(eval_preds: EvalPrediction, tokenizer):  
    # 解包评估预测和标签 ID
    batched_pred_ids, batched_label_ids = eval_preds
    # 初始化一个字典,用于存储不同评估指标的分数
        metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
        # 遍历批次中的预测 ID 和标签 ID
        for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
            # 将预测 ID 解码为文本,并去除首尾空格
            pred_txt = tokenizer.decode(pred_ids).strip()
            # 将标签 ID 解码为文本,并去除首尾空格
            label_txt = tokenizer.decode(label_ids).strip()
            # 使用结巴分词对预测文本进行分词
            pred_tokens = list(jieba.cut(pred_txt))
            # 使用结巴分词对标签文本进行分词
            label_tokens = list(jieba.cut(label_txt))
            # 创建 Rouge 评分的实例
            rouge = Rouge()
            # 计算 Rouge 分数,获取得分字典
            scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
            # 遍历第一个得分字典的键值对
            for k, v in scores[0].items():
                # 将 F 值乘以 100 后四舍五入,存储到对应的指标列表中
                metrics_dct[k].append(round(v['f'] * 100, 4))
            # 计算 BLEU-4 分数,并存储到字典中
            metrics_dct['bleu-4'].append(
                sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
        # 返回每个指标的平均分数字典
        return {k: np.mean(v) for k, v in metrics_dct.items()}
# 定义命令行工具的主入口函数
@app.command()
def main(
        # 指定数据目录,帮助信息为空
        data_dir: Annotated[str, typer.Argument(help='')],
        # 指定模型目录或模型配置文件路径,并提供帮助信息
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        # 指定配置文件路径,帮助信息为空
        config_file: Annotated[str, typer.Argument(help='')],
        # 自动恢复训练的检查点选项,默认值为空字符串
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
        ),
):
    # 从配置文件加载微调配置
    ft_config = FinetuningConfig.from_file(config_file)
    # 加载分词器和模型,传入微调配置
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    # 创建数据管理对象,传入数据目录和数据配置
    data_manager = DataManager(data_dir, ft_config.data_config)

    # 获取训练数据集,处理批次数据
    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    # 打印训练数据集
    print('train_dataset:', train_dataset)
    # 获取验证数据集,处理批次数据
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    # 如果验证数据集不为空,则打印
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    # 获取测试数据集,处理批次数据
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    # 如果测试数据集不为空,则打印
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    # 启用模型的梯度检查点
    model.gradient_checkpointing_enable()
    # 启用模型输入的梯度计算
    model.enable_input_require_grads()
    
    # 设置生成配置的填充标记ID
    ft_config.training_args.generation_config.pad_token_id = (
        151329
    )
    # 设置生成配置的结束标记ID
    ft_config.training_args.generation_config.eos_token_id = [
        151329, 151336, 151338
    ]

    # 初始化序列到序列训练器
    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        # 设置数据整理器
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # 设置计算指标的函数
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    # 如果未选择自动恢复检查点,则开始训练
    if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
        trainer.train()
    else:  # 如果不是首次训练,则执行以下逻辑
        output_dir = ft_config.training_args.output_dir  # 获取输出目录路径
        dirlist = os.listdir(output_dir)  # 列出输出目录下的所有文件和文件夹
        checkpoint_sn = 0  # 初始化检查点序号为 0
        for checkpoint_str in dirlist:  # 遍历输出目录中的每个项
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:  # 检查项是否包含 "eckpoint" 且不包含 "tmp"
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))  # 提取数字部分作为检查点编号
                if checkpoint > checkpoint_sn:  # 如果当前检查点编号大于已记录的最大值
                    checkpoint_sn = checkpoint  # 更新最大检查点编号
        if auto_resume_from_checkpoint.upper() == "YES":  # 如果设置为自动从检查点恢复
            if checkpoint_sn > 0:  # 如果找到有效的检查点编号
                model.gradient_checkpointing_enable()  # 启用模型的梯度检查点功能
                model.enable_input_require_grads()  # 启用输入的梯度计算
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))  # 构建检查点目录的完整路径
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))  # 输出正在恢复的检查点信息
                trainer.train(resume_from_checkpoint=checkpoint_directory)  # 从指定检查点恢复训练
            else:  # 如果没有找到有效的检查点
                trainer.train()  # 开始新的训练
        else:  # 如果不自动恢复检查点
            if auto_resume_from_checkpoint.isdigit():  # 如果指定的恢复检查点是数字
                if int(auto_resume_from_checkpoint) > 0:  # 检查点编号大于 0
                    checkpoint_sn = int(auto_resume_from_checkpoint)  # 设置检查点编号
                    model.gradient_checkpointing_enable()  # 启用模型的梯度检查点功能
                    model.enable_input_require_grads()  # 启用输入的梯度计算
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))  # 构建检查点目录的完整路径
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))  # 输出正在恢复的检查点信息
                    trainer.train(resume_from_checkpoint=checkpoint_directory)  # 从指定检查点恢复训练
            else:  # 如果指定的恢复检查点不是有效数字
                print(auto_resume_from_checkpoint,  # 输出自动恢复检查点的信息
                      "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")  # 提示用户指定的检查点不存在

    if test_dataset is not None:  # 如果测试数据集不为空
        trainer.predict(test_dataset)  # 使用训练器对测试数据集进行预测
# 检查当前模块是否是主程序
if __name__ == '__main__':
    # 调用应用程序的主函数
    app()

.\chatglm4-finetune\finetune_demo\finetune_vision.py

# -*- coding: utf-8 -*-  # 指定文件编码为 UTF-8
import os  # 导入操作系统功能模块
import jieba  # 导入中文分词库
import dataclasses as dc  # 导入数据类模块并重命名为 dc
import functools  # 导入高阶函数模块
from collections.abc import Callable, Mapping, Sequence  # 导入集合相关的类型
from pathlib import Path  # 导入路径处理模块
from typing import Annotated, Any, Union  # 导入类型提示相关模块
import numpy as np  # 导入 NumPy 数组处理库
import ruamel.yaml as yaml  # 导入 YAML 处理库
import torch  # 导入 PyTorch 深度学习框架
import typer  # 导入命令行界面库
from datasets import Dataset, Split  # 从 datasets 导入数据集和分割
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction  # 导入 BLEU 分数计算功能
from peft import PeftConfig, get_peft_config, get_peft_model  # 导入 PEFT 配置相关模块
from rouge_chinese import Rouge  # 导入中文 ROUGE 评估工具
from torch import nn  # 导入 PyTorch 神经网络模块
from transformers import (  # 从 transformers 库导入各种模型和工具
    AutoModelForCausalLM,  # 自动加载因果语言模型
    AutoTokenizer,  # 自动加载分词器
    EvalPrediction,  # 导入评估预测结果的类
    GenerationConfig,  # 导入生成配置类
    PreTrainedTokenizer,  # 导入预训练分词器
    Seq2SeqTrainingArguments,  # 导入序列到序列训练参数类
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq  # 导入序列到序列数据整理类并重命名
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer  # 导入序列到序列训练器并重命名
from datasets import load_dataset, DatasetDict, NamedSplit  # 导入数据集加载和字典功能
from typing import Optional  # 导入可选类型提示
from PIL import Image  # 导入图像处理库

app = typer.Typer(pretty_exceptions_show_locals=False)  # 创建 Typer 应用,禁用本地异常显示
img = Image.new('L', (224, 224), 0).convert('RGB')  # 创建一个 224x224 的黑色灰度图像并转换为 RGB

class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):  # 定义数据整理类,继承自 _DataCollatorForSeq2Seq
    def __call__(self, features, return_tensors=None):  # 定义调用方法,处理输入特征
        output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)  # 提取输出 ID
        if output_ids is not None:  # 检查输出 ID 是否存在
            max_output_length = max(len(out) for out in output_ids)  # 获取最大输出长度
            if self.pad_to_multiple_of is not None:  # 如果需要填充到特定倍数
                max_output_length = (  # 计算填充后的最大输出长度
                        (
                                max_output_length + self.pad_to_multiple_of - 1) //
                        self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:  # 遍历特征进行填充
                remainder = [self.tokenizer.pad_token_id] * (  # 创建填充列表
                        max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):  # 检查输出 ID 类型
                    feature['output_ids'] = feature['output_ids'] + remainder  # 列表形式直接拼接
                else:  # 否则使用 NumPy 进行拼接
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)  # 转换为 int64 类型
        return super().__call__(features, return_tensors)  # 调用父类的方法返回结果


class Seq2SeqTrainer(_Seq2SeqTrainer):  # 定义序列到序列训练器类,继承自 _Seq2SeqTrainer
    # Not Support for apex  # 说明不支持 apex
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:  # 定义训练步骤

        model.train()  # 将模型设置为训练模式
        inputs = self._prepare_inputs(inputs)  # 准备输入数据

        with self.compute_loss_context_manager():  # 计算损失的上下文管理器
            loss = self.compute_loss(model, inputs)  # 计算模型的损失

        if self.args.n_gpu > 1:  # 检查是否使用多 GPU
            loss = loss.mean()  # 如果是,取平均损失
        self.accelerator.backward(loss)  # 反向传播损失
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps  # 分离损失并进行梯度累积
        del inputs  # 删除输入数据以释放内存
        torch.cuda.empty_cache()  # 清空 CUDA 缓存
        return detached_loss  # 返回分离后的损失

    def prediction_step(  # 定义预测步骤
            self,
            model: nn.Module,  # 输入模型
            inputs: dict,  # 输入字典
            prediction_loss_only: bool,  # 是否仅计算预测损失
            ignore_keys=None,  # 可选的忽略键
            **gen_kwargs,  # 其他生成参数
    # 返回一个包含可选浮点数和两个可选张量的元组
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    
        # 禁用梯度计算,以减少内存使用和提高速度
        with torch.no_grad():
            # 如果设置为使用生成进行预测,则提取输出 ID
            if self.args.predict_with_generate:
                output_ids = inputs.pop('output_ids', None)
            # 调用父类的预测步骤方法,计算损失和生成的标记
            loss, generated_tokens, labels = super().prediction_step(
                model=model,  # 传入模型
                inputs=inputs,  # 传入输入数据
                prediction_loss_only=prediction_loss_only,  # 是否仅计算损失
                ignore_keys=ignore_keys,  # 忽略的键
                **gen_kwargs  # 其他生成参数
            )
    
            # 如果生成的标记不为空,则裁剪标记以移除输入部分
            if generated_tokens is not None:
                generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1]:]
    
            # 如果设置为使用生成进行预测,则将标签设置为输出 ID
            if self.args.predict_with_generate:
                labels = output_ids
    
            # 删除输入数据和输出 ID,以释放内存
            del inputs, output_ids
            # 清空 CUDA 缓存以释放显存
            torch.cuda.empty_cache()
    
        # 返回损失、生成的标记和标签
        return loss, generated_tokens, labels
# 使用 dataclass 装饰器定义数据配置类
@dc.dataclass
class DataConfig(object):
    # 训练文件的可选路径
    train_file: Optional[str] = None
    # 验证文件的可选路径
    val_file: Optional[str] = None
    # 测试文件的可选路径
    test_file: Optional[str] = None
    # 处理数据时使用的进程数量的可选值
    num_proc: Optional[int] = None

    # 定义一个只读属性,用于获取训练文件的后缀
    @property
    def data_format(self) -> str:
        # 返回训练文件的文件扩展名
        return Path(self.train_file).suffix

    # 定义一个只读属性,用于获取数据文件的字典
    @property
    def data_files(self) -> dict[NamedSplit, str]:
        # 生成包含数据集划分与对应文件路径的字典
        return {
            split: data_file
            for split, data_file in zip(
                # 列出数据集的划分类型
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                # 列出对应的文件路径
                [self.train_file, self.val_file, self.test_file],
            )
            # 仅包含文件路径不为 None 的条目
            if data_file is not None
        }


# 使用 dataclass 装饰器定义微调配置类
@dc.dataclass
class FinetuningConfig(object):
    # 数据配置的实例
    data_config: DataConfig

    # 最大输入长度
    max_input_length: int
    # 最大输出长度
    max_output_length: int
    # 是否合并数据的标志
    combine: bool
    # 是否冻结某些参数的标志
    freezeV: bool

    # 训练参数的实例,使用默认工厂函数初始化
    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    # 可选的 Peft 配置
    peft_config: Optional[PeftConfig] = None

    # 类的初始化后处理函数
    def __post_init__(self):
        # 如果不进行评估或验证文件为空,则禁用评估
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            # 设置评估策略为不评估
            self.training_args.evaluation_strategy = 'no'
            # 将验证文件设置为 None
            self.data_config.val_file = None
        else:
            # 设置评估批次大小,如果未定义则使用训练批次大小
            self.training_args.per_device_eval_batch_size = (
                    self.training_args.per_device_eval_batch_size
                    or self.training_args.per_device_train_batch_size
            )

    # 从字典创建 FinetuningConfig 实例的类方法
    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        # 从字典中获取训练参数
        training_args = kwargs.get('training_args', None)
        # 如果训练参数存在且不是 Seq2SeqTrainingArguments 类型
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            # 获取生成配置
            gen_config = training_args.get('generation_config')
            # 如果生成配置不是 GenerationConfig 类型,则进行转换
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            # 将训练参数转换为 Seq2SeqTrainingArguments 实例
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        # 从字典中获取数据配置
        data_config = kwargs.get('data_config')
        # 如果数据配置不是 DataConfig 类型,则进行转换
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        # 从字典中获取 Peft 配置
        peft_config = kwargs.get('peft_config', None)
        # 如果 Peft 配置存在且不是 PeftConfig 类型,则进行转换
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        # 创建 FinetuningConfig 实例并返回
        return cls(**kwargs)

    # 从文件创建 FinetuningConfig 实例的类方法
    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        # 将路径转换为 Path 对象
        path = Path(path)
        # 创建 YAML 解析器,使用安全模式
        parser = yaml.YAML(typ='safe', pure=True)
        # 设置解析器的缩进格式
        parser.indent(mapping=2, offset=2, sequence=4)
        # 设置默认的流样式为非流式
        parser.default_flow_style = False
        # 解析 YAML 文件并加载内容
        kwargs = parser.load(path)
        # 从解析后的字典中创建 FinetuningConfig 实例并返回
        return cls.from_dict(**kwargs)


# 定义一个加载数据集的私有函数
def _load_datasets(
        # 数据目录路径
        data_dir: str,
        # 数据格式
        data_format: str,
        # 数据文件字典
        data_files: dict[NamedSplit, str],
        # 进程数量的可选值
        num_proc: Optional[int],
) -> DatasetDict:
    # 检查数据格式是否为 JSON Lines 格式
        if data_format == '.jsonl':
            # 加载数据集,指定数据目录和文件,未划分子集,使用指定进程数
            dataset_dct = load_dataset(
                data_dir,
                data_files=data_files,
                split=None,
                num_proc=num_proc,
            )
        else:
            # 如果数据格式不是支持的格式,则引发未实现错误
            raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
        # 返回加载的数据集字典
        return dataset_dct
# 数据管理器类,用于处理数据集相关操作
class DataManager(object):
    # 初始化方法,接收数据目录和数据配置作为参数
    def __init__(self, data_dir: str, data_config: DataConfig):
        # 从数据配置中获取进程数量
        self._num_proc = data_config.num_proc

        # 加载数据集,并存储为字典
        self._dataset_dct = _load_datasets(
            data_dir,  # 数据目录
            data_config.data_format,  # 数据格式
            data_config.data_files,  # 数据文件列表
            self._num_proc,  # 进程数量
        )

    # 获取指定划分的数据集,如果不存在则返回 None
    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)  # 从字典中获取数据集

    # 获取处理过的数据集,支持批处理和原始列删除
    def get_dataset(
            self,
            split: NamedSplit,  # 数据集划分
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],  # 处理函数
            batched: bool = True,  # 是否批处理
            remove_orig_columns: bool = True,  # 是否移除原始列
    ) -> Optional[Dataset]:
        # 获取原始数据集
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:  # 如果数据集不存在
            return  # 返回 None
        if remove_orig_columns:  # 如果需要移除原始列
            remove_columns = orig_dataset.column_names  # 获取列名
        else:
            remove_columns = None  # 不移除列
        # 对原始数据集应用处理函数并返回结果
        return orig_dataset.map(
            process_fn,  # 处理函数
            batched=batched,  # 是否批处理
            remove_columns=remove_columns,  # 需要移除的列
            num_proc=self._num_proc,  # 进程数量
            # 默认的 orig_dataset.map 参数,可以调整为更小
            # https://github.com/THUDM/GLM-4/issues/277
            writer_batch_size=1000,  # 写入时的批处理大小
            batch_size=1000,  # 处理时的批处理大小
        )


# 处理批次数据的函数
def process_batch(
        batch: Mapping[str, Sequence],  # 输入批次数据
        tokenizer: PreTrainedTokenizer,  # 预训练的分词器
        max_input_length: int,  # 最大输入长度
        max_output_length: int,  # 最大输出长度
        combine: bool,  # 是否合并
) -> dict[str, list]:  # 返回处理后的字典
    # 获取批次中的消息
    batched_conv = batch['messages']
    # 初始化各类批处理列表
    batched_input_ids = []  # 输入 ID 列表
    batched_attention_mask = []  # 注意力掩码列表
    batched_position_ids = []  # 位置 ID 列表
    batched_labels = []  # 标签列表
    batched_images = []  # 图像列表

    # 计算最大长度
    max_length = max_input_length + max_output_length
    # 遍历每个批次的对话
        for conv in batched_conv:
            # 初始化输入 ID 列表
            input_ids = [151331, 151333]
            # 初始化注意力掩码列表
            attention_mask = [1, 1]
            # 创建位置 ID 列表
            position_ids = list(range(len(input_ids)))
            # 初始化损失掩码列表
            loss_masks = [False, False]
            # 初始化图像列表
            images = []
            
            # 检查对话的第一个元素是否有图像
            if conv[0].get('image'):
                # 打开图像并转换为 RGB 模式
                conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
            else:
                # 如果没有图像,则使用默认图像
                conv[0]['image'] = img
    
            # 遍历对话中的每条消息
            for message in conv:
                # 设置损失掩码值,基于消息的角色判断
                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
                # 应用聊天模板,对消息进行标记化
                new_input_ids_all = tokenizer.apply_chat_template(
                    [message],
                    tokenize=True,
                    return_dict=True,
                    padding=True
                )
                # 提取新输入 ID,去掉特殊标记
                new_input_ids = new_input_ids_all['input_ids'][0][2:]
                # 提取新注意力掩码,去掉特殊标记
                new_attention_mask = new_input_ids_all['attention_mask'][0][2:]
                # 创建新的位置 ID 列表
                new_position_ids = list(range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids)))
                # 如果消息有图像,则添加到图像列表
                if message.get('image'):  # 仅处理一张图像
                    images.append(new_input_ids_all['images'])
    
                # 创建新的损失掩码
                new_loss_masks = [loss_mask_val] * len(new_input_ids)
                # 更新输入 ID 列表
                input_ids += new_input_ids
                # 更新注意力掩码列表
                attention_mask += new_attention_mask
                # 更新位置 ID 列表
                position_ids += new_position_ids
                # 更新损失掩码列表
                loss_masks += new_loss_masks
    
            # 添加结束标记到输入 ID
            input_ids.append(151336)  # EOS
            # 添加结束标记的注意力掩码
            attention_mask.append(1)
            # 更新位置 ID 列表以包含结束标记
            position_ids.append(len(position_ids))
            # 添加结束标记的损失掩码
            loss_masks.append(False)
    
            # 初始化标签列表
            labels = []
            # 遍历输入 ID 和损失掩码,生成标签
            for input_id, mask in zip(input_ids, loss_masks):
                if mask:
                    # 如果掩码为真,则将输入 ID 添加到标签
                    labels.append(input_id)
                else:
                    # 否则添加 -100 表示忽略
                    labels.append(-100)
    
            # 添加批处理输入 ID 到列表,限制长度
            batched_input_ids.append(input_ids[:max_length])
            # 添加批处理注意力掩码到列表,限制长度
            batched_attention_mask.append(attention_mask[:max_length])
            # 添加批处理位置 ID 到列表,限制长度
            batched_position_ids.append(position_ids[:max_length])
            # 添加批处理标签到列表,限制长度
            batched_labels.append(labels[:max_length])
            # 添加第一张图像到批处理图像列表
            batched_images.append(images[0][0])
    
        # 删除临时变量以释放内存
        del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
        # 清空 GPU 缓存以释放内存
        torch.cuda.empty_cache()
    
        # 返回结果字典,包含所有批处理数据
        return {
            'input_ids': batched_input_ids,
            'attention_mask': batched_attention_mask,
            'position_ids': batched_position_ids,
            'labels': batched_labels,
            'images': batched_images
        }
# 处理批量评估的函数,接受批量数据、分词器、输入输出长度等参数,返回处理结果字典
def process_batch_eval(
        batch: Mapping[str, Sequence],  # 批量输入,包含消息的映射
        tokenizer: PreTrainedTokenizer,  # 预训练的分词器
        max_input_length: int,  # 最大输入长度限制
        max_output_length: int,  # 最大输出长度限制
        combine: bool,  # 是否合并处理标志
) -> dict[str, list]:  # 返回字典,键为字符串,值为列表
    # 从批量数据中提取消息部分
    batched_conv = batch['messages']
    # 初始化各类存储列表
    batched_input_ids = []  # 存储输入 ID 列表
    batched_attention_mask = []  # 存储注意力掩码列表
    batched_position_ids = []  # 存储位置 ID 列表
    batched_output_ids = []  # 存储输出 ID 列表
    batched_images = []  # 存储图像列表

    # 遍历每个对话
    for conv in batched_conv:
        # 如果对话包含图像,则打开并转换为 RGB 格式
        if conv[0].get('image'):
            image = Image.open(conv[0]['image']).convert('RGB')
        else:
            # 如果没有图像,使用默认图像
            image = img   
        
        # 将图像存回对话数据中
        conv[0]['image'] = image
        # 应用聊天模板分词,并返回分词结果
        new_input_ids_all = tokenizer.apply_chat_template(
            conv,
            tokenize=True,  # 是否分词
            return_dict=True,  # 返回字典格式
            padding=True  # 是否进行填充
        )

        # 提取分词后的输入 ID
        input_ids = new_input_ids_all['input_ids'][0]
        # 提取注意力掩码
        attention_mask = new_input_ids_all['attention_mask'][0]
        # 生成位置 ID 列表
        position_ids = list(range(len(input_ids)))

        # 初始化对话部分列表
        dialogue_parts = [0]
        # 遍历输入 ID,寻找对话分隔符
        for idx, token_id in enumerate(input_ids):
            if token_id == 151337:  # 特定标识符表示对话分隔
                dialogue_parts.append(idx + 1)

        # 如果没有对话部分或最后一部分未结束,添加结束位置
        if not dialogue_parts or dialogue_parts[-1] != len(input_ids):
            dialogue_parts.append(len(input_ids))

            # 将对话拆分为多个对话段
        for end_idx in range(1, len(dialogue_parts)):
            # 获取当前对话段的输入
            input_segment = input_ids[:dialogue_parts[end_idx]]
            # 获取当前对话段的注意力掩码
            attention_segment = attention_mask[:dialogue_parts[end_idx]]
            # 获取当前对话段的位置 ID
            position_segment = position_ids[:dialogue_parts[end_idx]]
            # 获取当前对话段的输出,添加结束符
            output_segment = input_ids[dialogue_parts[end_idx - 1]:dialogue_parts[end_idx]]
            output_segment.append(151336)  # 添加结束标识符

            # 将处理结果添加到批量列表中
            batched_input_ids.append(input_segment[:max_input_length])  # 限制输入长度
            batched_attention_mask.append(attention_segment[:max_input_length])  # 限制注意力掩码长度
            batched_position_ids.append(position_segment[:max_input_length])  # 限制位置 ID 长度
            batched_output_ids.append(output_segment[:max_output_length])  # 限制输出长度
            batched_images.append(new_input_ids_all['images'][0])  # 添加图像

    # 清理不再使用的变量以释放内存
    del batched_conv, input_ids, attention_mask, position_ids, new_input_ids_all, output_segment
    # 清空 CUDA 缓存以释放 GPU 内存
    torch.cuda.empty_cache()

    # 返回处理后的结果字典
    return {
        'input_ids': batched_input_ids,  # 输入 ID 列表
        'attention_mask': batched_attention_mask,  # 注意力掩码列表
        'position_ids': batched_position_ids,  # 位置 ID 列表
        'output_ids': batched_output_ids,  # 输出 ID 列表
        'images': batched_images  # 图像列表
    }


# 加载分词器和模型的函数,接受模型目录和可选的 PEFT 配置
def load_tokenizer_and_model(
        model_dir: str,  # 模型目录
        peft_config: Optional[PeftConfig] = None,  # 可选的 PEFT 配置
):
    # 从预训练模型目录加载分词器,信任远程代码
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    # 如果提供了 PEFT 配置,则加载模型
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,  # 模型目录
            trust_remote_code=True,  # 信任远程代码
            empty_init=False,  # 不进行空初始化
            use_cache=False,  # 禁用缓存
            torch_dtype=torch.bfloat16  # 使用 BFloat 16 数据类型
        )
        # 应用 PEFT 模型配置
        model = get_peft_model(model, peft_config)
        # 打印可训练参数
        model.print_trainable_parameters()
    # 如果前面的条件不满足,执行以下代码
        else:
            # 从指定的模型目录加载预训练的因果语言模型,允许使用远程代码
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,                          # 模型目录路径
                trust_remote_code=True,            # 信任远程代码
                empty_init=False,                  # 不使用空初始化
                use_cache=False,                   # 不使用缓存
                torch_dtype=torch.bfloat16         # 使用 bfloat16 数据类型
            )
        # 返回分词器和加载的模型
        return tokenizer, model
# 定义一个计算评估指标的函数,接收评估预测和分词器
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    # 解包评估预测,获取预测ID和标签ID
    batched_pred_ids, batched_label_ids = eval_preds
    # 初始化一个字典来存储各种指标的分数
    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    # 遍历每一组预测ID和标签ID
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        # 使用分词器解码预测ID为文本,并去除首尾空白
        pred_txt = tokenizer.decode(pred_ids).strip()
        # 使用分词器解码标签ID为文本,并去除首尾空白
        label_txt = tokenizer.decode(label_ids).strip()
        # 对预测文本进行分词,生成token列表
        pred_tokens = list(jieba.cut(pred_txt))
        # 对标签文本进行分词,生成token列表
        label_tokens = list(jieba.cut(label_txt))
        # 创建Rouge评分对象
        rouge = Rouge()
        # 计算Rouge分数,得到各项评分
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        # 遍历评分结果,保存F值到指标字典中
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        # 计算Bleu-4分数并保存到指标字典中
        metrics_dct['bleu-4'].append(
            sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
    # 返回每个指标的平均值
    return {k: np.mean(v) for k, v in metrics_dct.items()}


# 定义主命令行函数,接收多个参数
@app.command()
def main(
        # 数据目录参数
        data_dir: Annotated[str, typer.Argument(help='')],
        # 模型目录参数,包含模型配置的路径或ID
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        # 配置文件路径参数
        config_file: Annotated[str, typer.Argument(help='')],
        # 自动恢复检查点的参数,默认值为空字符串
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
        ),
):
    # 从配置文件加载微调配置
    ft_config = FinetuningConfig.from_file(config_file)
    # 加载分词器和模型
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    
    # 如果配置中冻结视觉参数,则不更新这些参数
    if ft_config.freezeV:
        for param in model.transformer.vision.parameters():
            param.requires_grad = False
    # 创建数据管理器,负责加载数据
    data_manager = DataManager(data_dir, ft_config.data_config)

    # 获取训练数据集,进行批处理
    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            combine=ft_config.combine, # 目前未使用的组合参数
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    # 打印训练数据集的信息
    print('train_dataset:', train_dataset)

    # 获取验证数据集,进行批处理
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            combine=ft_config.combine,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,

        ),
        batched=True,
    )

    # 如果验证数据集存在,则打印其信息
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    # 获取测试数据集,使用数据管理器,并对每个批次应用评估处理
        test_dataset = data_manager.get_dataset(
            Split.TEST,  # 指定数据集的拆分类型为测试集
            functools.partial(  # 使用偏函数来固定参数
                process_batch_eval,  # 处理每个批次的评估函数
                combine=ft_config.combine,  # 传入组合参数
                tokenizer=tokenizer,  # 传入分词器
                max_input_length=ft_config.max_input_length,  # 最大输入长度
                max_output_length=ft_config.max_output_length,  # 最大输出长度
            ),
            batched=True,  # 指定数据集为批处理模式
        )
        # 如果测试数据集不为空,则打印其内容
        if test_dataset is not None:
            print('test_dataset:', test_dataset)
    
        # 启用梯度检查点功能以节省内存
        model.gradient_checkpointing_enable()
        # 允许输入张量计算梯度
        model.enable_input_require_grads()
        
        # 设置生成配置中的填充标记ID
        ft_config.training_args.generation_config.pad_token_id = (
            151329  # 填充标记的ID
        )
        # 设置生成配置中的结束标记ID列表
        ft_config.training_args.generation_config.eos_token_id = [
            151329, 151336, 151338  # 结束标记的ID列表
        ]
    
        # 创建序列到序列训练器实例
        trainer = Seq2SeqTrainer(
            model=model,  # 指定使用的模型
            args=ft_config.training_args,  # 传入训练参数
            data_collator=DataCollatorForSeq2Seq(  # 数据整理器,用于处理输入数据
                tokenizer=tokenizer,  # 传入分词器
                padding='longest',  # 使用最长序列进行填充
                return_tensors='pt',  # 返回PyTorch张量
            ),
            train_dataset=train_dataset,  # 传入训练数据集
            eval_dataset=val_dataset,  # 传入评估数据集
            compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),  # 计算指标的偏函数
        )
    
        # 检查是否需要从检查点恢复训练
        if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
            trainer.train()  # 如果没有指定,直接开始训练
        else:
            output_dir = ft_config.training_args.output_dir  # 获取输出目录
            dirlist = os.listdir(output_dir)  # 列出输出目录中的文件
            checkpoint_sn = 0  # 初始化检查点序号
            # 遍历文件列表,查找有效的检查点
            for checkpoint_str in dirlist:
                if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:  # 检查文件名
                    checkpoint = int(checkpoint_str.replace("checkpoint-", ""))  # 提取检查点序号
                    if checkpoint > checkpoint_sn:  # 更新最大检查点序号
                        checkpoint_sn = checkpoint
            # 如果指定了要恢复的检查点
            if auto_resume_from_checkpoint.upper() == "YES":
                if checkpoint_sn > 0:  # 确保存在有效检查点
                    model.gradient_checkpointing_enable()  # 启用梯度检查点
                    model.enable_input_require_grads()  # 允许计算梯度
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))  # 构造检查点路径
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))  # 打印恢复信息
                    trainer.train(resume_from_checkpoint=checkpoint_directory)  # 从指定检查点恢复训练
                else:
                    trainer.train()  # 没有有效检查点,直接训练
            else:
                # 如果指定的恢复参数是数字
                if auto_resume_from_checkpoint.isdigit():
                    if int(auto_resume_from_checkpoint) > 0:  # 检查指定序号有效性
                        checkpoint_sn = int(auto_resume_from_checkpoint)  # 更新检查点序号
                        model.gradient_checkpointing_enable()  # 启用梯度检查点
                        model.enable_input_require_grads()  # 允许计算梯度
                        checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))  # 构造检查点路径
                        print("resume checkpoint from checkpoint-" + str(checkpoint_sn))  # 打印恢复信息
                        trainer.train(resume_from_checkpoint=checkpoint_directory)  # 从指定检查点恢复训练
                else:
                    # 如果指定的检查点无效,打印错误信息
                    print(auto_resume_from_checkpoint,
                          "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
    # 检查测试数据集是否不为空
        if test_dataset is not None:
            # 如果测试数据集存在,则进行预测
            trainer.predict(test_dataset)
# 如果当前脚本是主程序入口
if __name__ == '__main__':
    # 调用应用程序函数
    app()

.\chatglm4-finetune\finetune_demo\inference.py

# 从 pathlib 库导入 Path 类,用于路径操作
from pathlib import Path
# 导入用于类型注解的 Annotated 和 Union
from typing import Annotated, Union
# 导入 typer 库,用于创建命令行界面
import typer
# 从 peft 库导入 PeftModelForCausalLM 模型
from peft import PeftModelForCausalLM
# 从 transformers 库导入自动模型和自动标记器
from transformers import (
    AutoModel,
    AutoTokenizer,
)
# 从 PIL 库导入 Image,用于图像处理
from PIL import Image
# 导入 PyTorch 库
import torch

# 创建一个 typer 应用,设置不显示局部变量的异常信息
app = typer.Typer(pretty_exceptions_show_locals=False)


# 定义加载模型和标记器的函数,接收模型目录和信任远程代码的标志
def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
):
    # 解析并规范化模型目录路径
    model_dir = Path(model_dir).expanduser().resolve()
    # 检查 adapter_config.json 是否存在于模型目录
    if (model_dir / 'adapter_config.json').exists():
        # 导入 JSON 库用于解析配置文件
        import json
        # 打开并读取 adapter_config.json 配置文件
        with open(model_dir / 'adapter_config.json', 'r', encoding='utf-8') as file:
            config = json.load(file)
        # 根据配置文件加载基础模型
        model = AutoModel.from_pretrained(
            config.get('base_model_name_or_path'),
            trust_remote_code=trust_remote_code,
            device_map='auto',
            torch_dtype=torch.bfloat16
        )
        # 从预训练模型加载 Peft 模型
        model = PeftModelForCausalLM.from_pretrained(
            model=model,
            model_id=model_dir,
            trust_remote_code=trust_remote_code,
        )
        # 获取标记器目录
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        # 如果没有 adapter_config.json,直接根据模型目录加载基础模型
        model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=trust_remote_code,
            device_map='auto',
            torch_dtype=torch.bfloat16
        )
        # 设置标记器目录为模型目录
        tokenizer_dir = model_dir
    # 从预训练目录加载标记器
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        trust_remote_code=trust_remote_code,
        encode_special_tokens=True,
        use_fast=False
    )
    # 返回加载的模型和标记器
    return model, tokenizer


# 定义主命令函数,接收模型目录参数
@app.command()
def main(
        model_dir: Annotated[str, typer.Argument(help='')],
):
    # 为 GLM-4 进行无工具微调的消息示例
    messages = [
        {
            "role": "user", "content": "#裙子#夏天",
        }
    ]

    # 为 GLM-4 进行有工具微调的消息示例
    # messages = [
    #     {
    #         "role": "system", "content": "",
    #         "tools":
    #             [
    #                 {
    #                     "type": "function",
    #                     "function": {
    #                         "name": "create_calendar_event",
    #                         "description": "Create a new calendar event",
    #                         "parameters": {
    #                             "type": "object",
    #                             "properties": {
    #                                 "title": {
    #                                     "type": "string",
    #                                     "description": "The title of the event"
    #                                 },
    #                                 "start_time": {
    #                                     "type": "string",
    #                                     "description": "The start time of the event in the format YYYY-MM-DD HH:MM"
    #                                 },
    #                                 "end_time": {
    #                                     "type": "string",
    #                                     "description": "事件结束时间,格式为 YYYY-MM-DD HH:MM"
    #                                 }
    #                             },
    #                             "required": [
    #                                 "title",  # 事件的标题是必填项
    #                                 "start_time",  # 事件的开始时间是必填项
    #                                 "end_time"  # 事件的结束时间是必填项
    #                             ]
    #                         }
    #                     }
    #                 }
    #             ]
    #
    #     },
    #     {
    #         "role": "user",  # 消息的角色为用户
    #         "content": "能帮我创建一个明天会议的日历事件吗?标题是\"团队会议\",开始时间是上午10:00,结束时间是上午11:00。"  # 用户请求创建日历事件的内容
    #     },
    # ]
    
    # 为 GLM-4V 微调准备消息
    # messages = [
    #     {
    #         "role": "user",  # 消息的角色为用户
    #         "content": "女孩可能希望观众做什么?",  # 用户的问题内容
    #         "image": Image.open("your Image").convert("RGB")  # 打开图像文件并转换为 RGB 格式
    #     }
    # ]

    model, tokenizer = load_model_and_tokenizer(model_dir)  # 加载模型和分词器
    inputs = tokenizer.apply_chat_template(  # 应用聊天模板格式化输入消息
        messages,  # 传入的消息列表
        add_generation_prompt=True,  # 添加生成提示
        tokenize=True,  # 对输入进行分词
        return_tensors="pt",  # 返回 PyTorch 张量
        return_dict=True  # 返回字典格式
    ).to(model.device)  # 将输入张量转移到模型的设备上
    generate_kwargs = {  # 定义生成时的参数
        "max_new_tokens": 1024,  # 生成的最大新标记数
        "do_sample": True,  # 允许随机采样
        "top_p": 0.8,  # 采样时的累积概率阈值
        "temperature": 0.8,  # 控制生成文本的随机性
        "repetition_penalty": 1.2,  # 重复惩罚因子
        "eos_token_id": model.config.eos_token_id,  # 结束标记的 ID
    }
    outputs = model.generate(**inputs, **generate_kwargs)  # 生成模型的输出
    response = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True).strip()  # 解码生成的输出并去除特殊标记
    print("=========")  # 打印分隔符
    print(response)  # 输出生成的响应
# 如果当前脚本是主程序,则执行下面的代码
if __name__ == '__main__':
    # 调用应用程序的主函数
    app()

GLM-4-9B Chat 对话模型微调

Read this in English

本 demo 中,你将体验到如何微调 GLM-4-9B-Chat 对话开源模型(不支持视觉理解模型)。 请严格按照文档的步骤进行操作,以避免不必要的错误。

硬件检查

本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。微调的资源占用均按照
configs 文件夹中的配置文件设置

测试硬件信息:

  • OS: Ubuntu 22.04
  • Memory: 512GB
  • Python: 3.10.12 / 3.12.3 (如果您使用 Python 3.12.3 目前需要使用 git 源码安装 nltk)
  • CUDA Version: 12.3
  • GPU Driver: 535.104.05
  • GPU: NVIDIA A100-SXM4-80GB * 8
微调模型 微调方案 显存占用 权重保存点大小
GLM-4-9B-Chat lora (PEFT) 22G 17M
GLM-4-9B-Chat p-tuning v2 (PEFT) 21G 121M
GLM-4-9B-Chat SFT (Zero3 method) 80G (Each GPU,需要使用8张GPU) 20G
GLM-4V-9B lora (PEFT), 包含视觉模块 75G 37M
GLM-4V-9B SFT 本代码不支持 28G

GLM-4V-9B 微调无法可能正常使用 deepspeed,官方微调脚本仅做最基础的微调方案,更多优化需要开发者自行探索

在开始微调之前,请你先安装 basic_demo 中的依赖,并保证克隆了最新版本的模型仓库,同时您需要安装本目录下的依赖项:

pip install -r requirements.txt

NOTE: NLTK 3.8.1 部分代码可能尚未对 python 3.12
进行适配,该情况下的适配方法可参考issues #38

多轮对话格式

多轮对话微调示例采用 GLM-4 对话格式约定,对不同角色添加不同 loss_mask 从而在一遍计算中为多轮回复计算 loss

对于数据文件,样例采用如下格式

如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。

[
  {
    "messages": [
      {
        "role": "system",
        "content": "<system prompt text>",
        "tools": [
          {
            "name": "<tool name>",
            "args": {
              "<arg name>": "<arg value>"
            }
          }
          // Add more tools if needed
        ]
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      // If Tool Using
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "observation",
        "content": "<observation prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response observation>"
      },
      // Multi_turns
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }
    ]
  }
]

这里是一个不带有工具的例子:

{
  "messages": [
    {
      "role": "user",
      "content": "类型#裤*材质#牛仔布*风格#性感"
    },
    {
      "role": "assistant",
      "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"
    }
  ]
}

这是一个带有工具调用的例子:

{
  "messages": [
    {
      "role": "system",
      "content": "",
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "get_recommended_books",
            "description": "Get recommended books based on user's interests",
            "parameters": {
              "type": "object",
              "properties": {
                "interests": {
                  "type": "array",
                  "items": {
                    "type": "string"
                  },
                  "description": "The interests to recommend books for"
                }
              },
              "required": [
                "interests"
              ]
            }
          }
        }
      ]
    },
    {
      "role": "user",
      "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."
    },
    {
      "role": "assistant",
      "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"
    },
    {
      "role": "observation",
      "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"
    },
    {
      "role": "assistant",
      "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."
    }
  ]
}

这是一个视觉VQA微调的例子:

{
  "messages": [
    {
      "role": "user",
      "content": "图片中的动物是什么?",
      "image": "/root/images/0001.jpg"
    },
    {
      "role": "assistant",
      "content": "图片中有一只猫。"
    },
    {
      "role": "user",
      "content": "图片中的猫在做什么?"
    },
    {
      "role": "assistant",
      "content": "这只猫坐在或站在桌子上,桌上有很多食物。"
    }
  ]
}
  • system 角色为可选角色,但若存在 system 角色,其必须出现在 user
    角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 system 角色。
  • tools 字段为可选字段,若存在 tools 字段,其必须出现在 system
    角色之后,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 tools 字段。当 tools 字段存在时,system
    角色必须存在并且 content 字段为空。
  • GLM-4V-9B 不支持 tools 字段和 system 字段。并且 image 必须放在第一条消息中。 image
    字段需要放置置图片的 绝对路径

配置文件

微调配置文件位于 config 目录下,包括以下文件:

  1. ds_zereo_2 / ds_zereo_3.json: deepspeed 配置文件。
  2. `lora.yaml / ptuning_v2
  3. .yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下:
    • data_config 部分
      • train_file: 训练数据集的文件路径。
      • val_file: 验证数据集的文件路径。
      • test_file: 测试数据集的文件路径。
      • num_proc: 在加载数据时使用的进程数量。
    • max_input_length: 输入序列的最大长度。
    • max_output_length: 输出序列的最大长度。
    • training_args 部分
      • output_dir: 用于保存模型和其他输出的目录。
      • max_steps: 训练的最大步数。
      • per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。
      • dataloader_num_workers: 加载数据时使用的工作线程数量。
      • remove_unused_columns: 是否移除数据中未使用的列。
      • save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
      • save_steps: 每隔多少步保存一次模型。
      • log_level: 日志级别(如 info)。
      • logging_strategy: 日志记录策略。
      • logging_steps: 每隔多少步记录一次日志。
      • per_device_eval_batch_size: 每个设备的评估批次大小。
      • evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
      • eval_steps: 每隔多少步进行一次评估。
      • predict_with_generate: 是否使用生成模式进行预测。
    • generation_config 部分
      • max_new_tokens: 生成的最大新 token 数量。
    • peft_config 部分
      • peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。
      • task_type: 任务类型,这里是因果语言模型 (不要改动)。
    • Lora 参数:
      • r: LoRA 的秩。
      • lora_alpha: LoRA 的缩放因子。
      • lora_dropout: 在 LoRA 层使用的 dropout 概率。
    • P-TuningV2 参数:
      • num_virtual_tokens: 虚拟 token 的数量。
      • num_attention_heads: 2: P-TuningV2 的注意力头数(不要改动)。
      • token_dim: 256: P-TuningV2 的 token 维度(不要改动)。

开始微调

通过以下代码执行 单机多卡/多机多卡 运行,这是使用 deepspeed 作为加速方案的,您需要安装 deepspeed。接着,按照此命令运行:

OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8  finetune.py  data/AdvertiseGen/  THUDM/glm-4-9b-chat  configs/lora.yaml # For Chat Fine-tune
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8  finetune_vision.py  data/CogVLM-311K/  THUDM/glm-4v-9b  configs/lora.yaml  # For VQA Fine-tune

通过以下代码执行 单机单卡 运行。

python finetune.py  data/AdvertiseGen/  THUDM/glm-4-9b-chat  configs/lora.yaml # For Chat Fine-tune
python finetune_vision.py  data/CogVLM-311K/  THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune

从保存点进行微调

如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式:

  1. yes, 自动从最后一个保存的 Checkpoint开始训练
  2. XX, 断点号数字 例 600 则从序号600 Checkpoint开始训练

例如,这就是一个从最后一个保存点继续微调的示例代码

python finetune.py  data/AdvertiseGen/  THUDM/glm-4-9b-chat  configs/lora.yaml yes

使用微调后的模型

在 inference.py 中验证微调后的模型

您可以在 finetune_demo/inference.py 中使用我们的微调后的模型,仅需要一行代码就能简单的进行测试。

python inference.py your_finetune_path

这样,得到的回答就微调后的回答了。

在本仓库的其他 demo 或者外部仓库使用微调后的模型

您可以在任何一个 demo 内使用我们的 LORA 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。

  1. 使用finetune_demo/inference.py中读入模型的方式替换 demo 中读入模型的方式。

请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在adapter_config.json
中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改adapter_config.jsonbase_model_name_or_path的路径。

def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code
    )
    return model, tokenizer
  1. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为/path/to/finetune_adapter_model
    ,原始模型地址为path/to/base_model,则你应该使用/path/to/finetune_adapter_model作为model_dir
  2. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。
  3. 本微调脚本没有测试过128K 1M等长文本的微调,长文本的微调需要更大显存的GPU设备,并且需要更高效的微调方案,需要开发者自行解决。

参考文献


@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}

@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}

标签:4v,self,9B,ids,源码,output,input,model,config
From: https://www.cnblogs.com/apachecn/p/18491989

相关文章