首页 > 其他分享 >聊聊GLM-4-9B开源模型的微调loss计算

聊聊GLM-4-9B开源模型的微调loss计算

时间:2024-06-12 10:36:06浏览次数:22  
标签:loss GLM 9B ids length role input message

概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

可了解其它loss计算的文章:
再聊多轮对话微调训练格式与长序列训练
聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析
聊聊大模型多轮对话的训练及优化

微调

微调格式:

[
  {
    "messages": [
      {
        "role": "system",
        "content": "<system prompt text>",
        "tools": [
          {
            "name": "<tool name>",
            "args": {
              "<arg name>": "<arg value>"
            }
          }
        ]
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "observation",
        "content": "<observation prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response observation>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }
    ]
  }
]

微调源码地址:finetune.py
Loss计算代码:

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    # batched_conv 是一个数组
    # conv 是数组内的单个 message
    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        # conv 是数组内的单个 message
        # message 是 单个role json对象
        for message in conv:
            message = process_message(message)
            # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            # 获取 input 文本的数字表示(ids)
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            # 计算整句的 mask
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            # 拼接message中的每段json
            input_ids += new_input_ids
            # 拼接message中每段json对应的mask
            loss_masks += new_loss_masks
        # 追加结尾的 token id
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                # 添加到label,计算loss
                labels.append(input_id)
            else:
                # -100 不处理,即ignore_index
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        # 截断
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)
print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md

原文链接:https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg

更多大模型相关的文章,请上个人公众号查阅:
image

标签:loss,GLM,9B,ids,length,role,input,message
From: https://www.cnblogs.com/zhiyong-ITNote/p/18243420

相关文章

  • AttributeError: ‘ChatGLMModel‘ object has no attribute ‘prefix_encoder‘
    AttributeError:‘ChatGLMModel‘objecthasnoattribute‘prefix_encoder‘:全面解析问题概述当您使用ChatGLM模型或相关库时遇到AttributeError:‘ChatGLMModel‘objecthasnoattribute‘prefix_encoder‘错误时,这意味着ChatGLMModel类中不存在prefix_encod......
  • 打败GPT-4的最强开源中文大模型GLM-4终于亮相了(附:超详细搭建过程)
    GLM-4是由智谱AI推出的新一代基座预处理大模型,具有与GPT-4相近的性能,尤其在中文能力上可以比肩GPT-4。它在多个方面进行了优化和提升,包括支持更长的上下文长度、更快的推理速度、降低推理成本,以及增强了智能体能力。GLM-4能够处理128k的上下文窗口长度,单次提示词可以处理的文......
  • 深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss
    深入理解交叉熵损失CrossEntropyLoss-CrossEntropyLossflyfish本系列的主要内容是在2017年所写,GPT使用了交叉熵损失函数,所以就温故而知新,文中代码又用新版的PyTorch写了一遍,在看交叉熵损失函数遇到问题时,可先看链接提供的基础知识,可以有更深的理解。深入理解交叉熵损......
  • 本地部署GLM-4-9B清华智谱开源大模型方法和对话效果体验
    GLM-4-9B是清华大学和智谱AI推出的最新一代预训练模型GLM-4系列中的开源版本。在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B及其人类偏好对齐的版本GLM-4-9B-Chat均表现出较高的性能,其通用能力评测结果甚至超越了Llama-3-8B开源大模型,多模态版本也与GPT-4版本齐......
  • 【YOLOv8改进】SlideLoss损失函数,解决样本不平衡问题
    YOLO目标检测创新改进与实战案例专栏专栏目录:YOLO有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLO基础解析+创新改进+实战案例介绍摘要近年来,基于深度学习的人脸检测算法取得了很大进展。这些......
  • 【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战
    ​​​​​​​目录一、引言二、模型简介2.1GLM4-9B 模型概述2.2GLM4-9B 模型架构三、模型推理3.1GLM4-9B-Chat语言模型3.1.1 model.generate 3.1.2 model.chat3.2GLM-4V-9B多模态模型3.2.1多模态模型概述3.2.2 多模态模型实践四、总结 一、引言......
  • chatglm4 多显卡部署
    importtorchfromtransformersimportAutoModelForCausalLM,AutoTokenizerimportosos.environ['HF_ENDPOINT']='https://hf-mirror.com'#加上这行之后又恢复以前的速度了!device="cuda"print("是否可用:",torch.cuda.is_available())......
  • GLM-4-9B领先!伯克利函数调用榜单BFCL的Function Calling评测方法解析与梳理
    智谱公布的GLM-4-9B基于BFCL榜单的工具调用能力测试结果©作者|格林来源|神州问学在智谱最新开源的GLM-4-9B-Chat中,其工具调用能力在BFCL(伯克利函数调用排行榜)榜上获得了超高的总BFCL分,和gpt-4-turbo-2024-04-09几乎不相上下。在榜单中,还提到了AST总分以及Exec总分两个......
  • GLM-4已经“低调”开源了
    GLM-4-9B是智谱AI推出的最新一代预训练模型GLM-4系列中的开源版本。在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B及其人类偏好对齐的版本GLM-4-9B-Chat均表现出较高的性能。除了能进行多轮对话,GLM-4-9B-Chat还具备网页浏览、代码执行、自定义......
  • 复现GLM4-9B
    简介GLM-4-9B是智谱AI推出的最新一代预训练模型GLM-4系列中的开源版本。在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B表现出超越Llama-3-8B的卓越性能。除了能进行多轮对话,GLM-4-9B-Chat还具备网页浏览、代码执行、自定义工具调用(FunctionCa......