首页 > 其他分享 >Baichuan2 模型详解,附实验代码复现

Baichuan2 模型详解,附实验代码复现

时间:2024-11-21 13:17:15浏览次数:3  
标签:Baichuan prompt 模型 results 详解 复现 print Baichuan2 model

简介

近年来,大规模语言模型(LLM)领域取得了令人瞩目的进展。语言模型的参数规模从早期的数百万(如 ELMo、GPT-1),发展到如今的数十亿甚至上万亿(如 GPT-3、PaLM 和 Switch Transformers)。随着模型规模的增长,LLM 的能力显著提升,展现出更接近人类的语言流畅性,并能执行多样化的自然语言任务。ChatGPT 的推出进一步证明了 LLM 在生成类人文本方面的强大能力,引起了广泛关注。然而,当前主流的 LLM(如 GPT-4、PaLM-2、Claude 等)大多为闭源,这限制了研究人员对模型的深入研究和改进。相较之下,Meta 开发的 LLaMA 系列和其他开源 LLM(如 OPT、Bloom、MPT、Falcon 等)为社区提供了自由访问和实验的机会,加速了该领域的研究与发展。

尽管开源 LLM 推动了技术进步,但大多数开源模型主要面向英语,限制了其他语言(如中文)的模型开发与应用。在此背景下,Baichuan 2 模型应运而生。

Baichuan 2 是一系列大规模多语言语言模型,包括 Baichuan 2-7B(70 亿参数)和 Baichuan 2-13B(130 亿参数)。它们在 2.6 万亿标记上训练,比 Baichuan 1 的数据量增加一倍多,在通用基准(如 MMLU、CMMLU、C-Eval)上表现出显著提升,尤其是在数学和代码问题的解决能力上翻倍。同时,Baichuan 2 在医学和法律领域的任务中表现优异,是针对专业领域优化的理想基础模型。

此外,Baichuan 2 还推出了对话优化版本 Baichuan 2-7B-ChatBaichuan 2-13B-Chat,在指令跟随和对话理解方面表现出色。为推动社区研究,Baichuan 2 开放了从2000亿到2.6万亿标记的各阶段训练检查点,帮助研究者了解模型训练动态,并推动更负责任的 LLM 开发。

Baichuan 2 的基础模型和对话模型现已开源,可用于研究和商业目的。

代码仓库地址:https://github.com/baichuan-inc/Baichuan2

论文创新点

训练数据

数据来源:在数据采集过程中,Baichuan2通过从多种来源获取数据,包括互联网网页、书籍、研究论文、代码库等,构建了一个涵盖广泛世界知识的训练语料体系。训练语料的具体构成如图1所示。Baichuan2使用了2.6T tokens,是Baichuan1的两倍,训练的数据分布中,技术、商业和娱乐类别的语料最多。

image-20241116190210344

数据处理:在数据处理环节,研究者重点关注数据的频率与质量。利用大规模去重与聚类系统,通过支持类似 LSH 的特性和密集嵌入特性,实现对万亿级数据的快速聚类和去重(数小时内完成)。在聚类基础上,对文档、段落和句子进行去重和评分,并根据评分进行预训练阶段的数据采样。各阶段处理后的训练数据规模如图2所示。

image-20241116190242449

模型架构

Tokenizer

Baichuan 2 的分词器设计平衡了高压缩率与合适的词汇规模,以提升推理效率和词嵌入训练质量。相比 Baichuan 1,将词汇表从 64,000 扩展到 125,696,以优化计算效率和模型性能。分词采用基于 SentencePiece 的字节对编码(BPE),未对输入文本进行归一化,也未添加虚拟前缀;数字被拆分为单独的数字编码,对代码数据中的多余空格通过添加仅包含空格的分词来处理。字符覆盖率设为 0.9999,稀有字符回退到 UTF-8 字节,最大分词长度设为 32 以适应长中文短语。分词器的训练数据源自 Baichuan 2 的预训练语料,增加了代码示例和学术论文的采样以改善覆盖率。表 2 显示了Baichuan2 的分词器与其他分词器的详细比较。

image-20241116190653635

位置嵌入

Baichuan 2 的设计中,论文基于 Baichuan 1,为 Baichuan 2-7B 采用 Rotary Positional Embedding (RoPE),为 Baichuan 2-13B 采用 ALiBi 作为位置编码方案。ALiBi 是一种更先进的技术,在外推性能上表现更优。然而,大多数开源模型使用基于乘法的 RoPE,并与 Flash Attention 等优化的注意力实现更为兼容,因为 RoPE 避免了在注意力操作中传递 attention_mask 的需求。尽管如此,初步实验表明,位置编码的选择对模型性能影响不显著。为了支持基于偏差和乘法的注意力机制研究,论文在 Baichuan 2-7B 中应用 RoPE,在 Baichuan 2-13B 中应用 ALiBi,与 Baichuan 1 保持一致。

激活和标准化

Baichuan 2 采用了 SwiGLU 激活函数,这是一种改进的 GLU 变体,能够显著提升性能。由于 SwiGLU 包含三组参数矩阵,与普通 Transformer 前馈层的两组矩阵不同,模型将隐藏层大小从原来的 4 倍调整为 8 3 \frac{8}{3} 38​ 倍,并四舍五入为 128 的倍数以优化计算。注意力层采用由 xFormers 实现的 高效注意力机制,支持 ALiBi 的偏置位置编码,同时减少了内存开销,提升了训练的性能和效率。在 Transformer 块的输入处应用 Layer Normalization,增强了对学习率预热的鲁棒性。此外,模型使用了 RMSNorm 实现,仅计算输入特征的方差以进一步提高效率。

优化策略

Baichuan 2 使用 AdamW 优化器进行训练,其中 β 1 \beta_1 β1​ 和 β 2 \beta_2 β2​ 分别设置为 0.9 和 0.95,权重衰减系数为 0.1,梯度范数裁剪为 0.5。模型采用 2,000 步线性学习率预热,达到最大学习率后应用余弦衰减至最小学习率。参数细节和学习率设定详见表3。

模型训练采用 BFloat16 混合精度模式。相较于 Float16,BFloat16 提供更大的动态范围,在处理大型模型训练中的大值时更加鲁棒。然而,BFloat16 的低精度在某些场景中可能引发问题,例如在一些公开的 RoPE 和 ALiBi 实现中,当整数值超过 256 时,torch.arange 操作可能因冲突无法区分相邻位置。为解决此问题,对位置嵌入等对数值敏感的操作使用全精度计算。

为稳定训练并提升模型性能,引入了 NormHead 机制,对输出嵌入(head)进行归一化。实验表明,NormHead 有两大优势。首先,在初步实验中发现输出嵌入的范数易出现不稳定,稀有词的嵌入范数在训练中趋于减小,从而干扰训练动态。NormHead 能显著稳定这种动态。其次,语义信息主要通过嵌入的余弦相似性编码,而线性分类器通过点积计算 logits 时会混合 L2 距离和余弦相似性,NormHead 减少了 L2 距离对 logits 计算的干扰。

训练过程中,模型 logits 可能变得过大。尽管 softmax 仅依赖 logits 的相对值,但在推理阶段,这种情况会导致重复惩罚机制(如 Hugging Face 的实现)直接对 logits 施加标量(如 1.1 或 1.2),进而显著改变 softmax 后的概率分布,使模型对超参数的选择更敏感。为解决此问题,借鉴 NormSoftmax 和 PaLM 的辅助 z − l o s s z-loss z−loss,模型添加了 m a x − z l o s s max-z loss max−zloss,用于对 logits 进行归一化。公式为:

L max-z = 2 e − 4 × z 2 L_{\text{max-z}} = 2 e^{-4} \times z^2 Lmax-z​=2e−4×z2
其中 z z z​ 为最大 logit 值。此方法稳定了训练过程,并增强了推理时对超参数的鲁棒性。

image-20241116191453977

Scaling Laws

神经网络的缩放定律表明,误差随着训练数据集规模或模型规模的幂函数关系减少。这一规律为深度学习和大规模语言模型的高成本训练提供了性能保障。在训练数十亿参数规模的大型语言模型之前,通常先训练较小规模的模型,并通过这些模型的结果拟合缩放定律,为更大模型的训练提供参考。

Baichuan 2 在模型规模从 10M 到 3B 范围内进行了一系列实验,这些规模从最终模型大小的 1/1000 到 1/10 不等。每个模型使用一致的超参数和来自 Baichuan 2 的同一数据集,训练规模最大达到 1 万亿标记。基于不同规模模型的最终损失,建立了从训练计算量(flops)到目标损失的映射关系。

为了拟合模型的缩放定律,采用了 Henighan 等人提出的公式:

L C = a ⋅ C b + L ∞ L_C = a \cdot C^b + L_\infty LC​=a⋅Cb+L∞​
其中, L ∞ L_\infty L∞​ 表示不可减少的损失,第一项为可减少的损失,采用幂律缩放形式表示; C C C 为训练计算量, L C L_C LC​​ 为在该计算量下的最终损失值。参数通过 SciPy 库的 curve_fit 函数进行拟合。最终拟合得到的缩放曲线以及对 7B 和 13B 参数模型的最终损失预测结果如图4所示。结果表明,拟合的缩放定律可以高度准确地预测 Baichuan 2 的最终损失。

image-20241116191845956

预训练策略

当前大规模语言模型的训练需要高效利用现有 GPU 资源。为此,Baichuan 2 的训练采用了一种协同设计的方法,结合弹性训练框架与智能集群调度策略。在共享 GPU 集群中,由于任务行为的不确定性,常常会出现节点空闲的现象。因此,设计重点在于实现机器级弹性,允许根据集群状态动态调整任务资源,从而支持智能调度算法。为满足弹性需求,训练框架整合了张量并行(Tensor Parallelism)与基于 ZeRO 的数据并行(ZeRO-powered Data Parallelism)。其中,每台机器内部采用张量并行,而跨机器则使用 ZeRO 数据并行以支持弹性扩展。

此外,为减少内存峰值消耗,框架引入张量分割技术,例如对大词汇表的交叉熵计算进行分割,在不增加额外计算和通信的情况下满足内存需求,提高系统效率。混合精度训练也被应用,其中前向与反向计算使用 BFloat16 精度,而优化器更新则使用 Float32 精度,以加速训练同时保证模型精度。

在扩展到数千个 GPU 的大规模训练集群时,为避免通信效率的下降,框架集成了以下技术:

  1. 拓扑感知的分布式训练:通过在多层交换机网络中优化训练任务的分配,减少跨交换机访问的频率,从而降低延迟并提升通信效率。

  2. ZeRO 的混合分区与分层分区:通过在 GPU 间分区参数来降低内存消耗,但这种方式增加了全聚合通信的负担。为应对这种通信瓶颈,框架采用混合与分层分区策略,首先在所有 GPU 间分区优化器状态,然后根据需求动态决定哪些层需要启用 ZeRO3 或是否分层分区参数。

这些优化策略使得 Baichuan 2-7B 和 Baichuan 2-13B 模型能够在 1,024 个 NVIDIA A800 GPU 上高效训练,达到超过 180 TFLOPS 的计算效率。

对齐策略

Baichuan 2 的对齐过程引入了两种主要方法:监督微调(SFT)基于人类反馈的强化学习(RLHF),最终产生了两个对话模型 Baichuan 2-7B-Chat 和 Baichuan 2-13B-Chat。

监督微调(SFT)

在监督微调阶段,利用人工标注对从多种数据来源收集的提示进行分类,依据有用性和无害性原则对每条提示进行打分。为保证数据质量,通过交叉验证机制,使用权威标注者检查特定标注组的样本批次,拒绝不符合质量标准的批次。最终收集了超过 10 万条监督微调样本,并在此基础上训练了基础模型。

奖励模型

为了进一步优化,Baichuan 2 设计了三层分类体系,包括 6 个主类别、30 个次类别以及超过 200 个细分类别,力求从用户视角全面覆盖需求,同时从奖励模型训练的角度确保各类别提示的多样性。为增强响应的多样性,提示的回答由 Baichuan 2 不同规模和不同训练阶段(如 SFT、PPO)生成,奖励模型的训练仅使用 Baichuan 2 模型家族生成的回答,而不使用其他开源或专有模型生成的数据。这种一致性进一步凸显了 Baichuan 2 模型家族的内在协调性。奖励模型的训练损失函数与 InstructGPT 类似,其结果表明,模型对分数差异较大的回答具有更高的区分准确性,与 LLaMA 2 表现一致。

image-20241116223118564

基于 PPO 的强化学习

完成奖励模型的训练后,利用 PPO 算法对语言模型进行强化学习训练。训练过程涉及四个模型:生成响应的 actor 模型、用于计算 KL 惩罚的固定参数 reference 模型、提供整体奖励的固定参数 reward 模型 以及学习逐标记值的 critic 模型

image-20241116223104498

训练细节

在 RLHF 训练过程中,critic 模型首先经过 20 步的预热训练,然后与 actor 模型通过标准 PPO 算法共同更新。训练采用梯度裁剪(0.5)、常数学习率(5e-6)、PPO 剪辑阈值(ε = 0.1)以及 KL 惩罚系数(β 从 0.2 衰减至 0.005)。模型共训练 350 个迭代周期,最终生成了 Baichuan 2-7B-Chat 和 Baichuan 2-13B-Chat 模型。

数据集上的评价指标得分

整体表现

  1. MMLU (Massive Multitask Language Understanding):由多个学术科目的多项选择题组成,用于评估语言模型的广泛知识水平。

  2. C-Eval:一个全面的中文评估基准,包括超过 10,000 道多项选择题,用于测试模型的中文语言能力。

  3. CMMLU:专为评估语言模型在中文语言和文化背景下的知识和推理能力设计的通用评估基准。

  4. AGIEval:以人为中心的评估基准,用于测试语言模型在认知和问题解决等通用能力方面的表现。

  5. Gaokao:利用中国高中入学考试问题的评估框架,用于衡量语言模型在学术领域的知识深度。

  6. BBH (BIG-Bench Hard):一组具有挑战性的任务集合,测试语言模型是否能超越平均人类评审员的表现。

  7. GSM8K:专注于数学问题的评估基准,用于测试语言模型的数学推理能力。

  8. HumanEval:一个从文档字符串到代码的评估数据集,包含 164 个编程问题,用于测试编程逻辑的多方面能力。

image-20241116223207894

垂直领域评估

  1. JEC-QA:从中国国家司法考试中收集的多项选择题和多选题,用于评估法律领域的知识水平,仅测试多项选择题。

  2. MedQA (USMLE):从美国和中国的专业医学执业考试(USMLE 和 MCMLE)收集的数据集,评估临床知识和医学推理能力。

  3. MedMCQA:来自印度医学入学考试的多项选择题数据集,用于测试医学领域的基础和临床知识,报告开发集的成绩。

  4. C-Eval (val):中文评估基准的医学相关学科子集,包括临床医学和基础医学,用于测试模型在医学领域的表现。

  5. MMLU (medical disciplines):评估基准中的医学相关学科,包括临床知识、解剖学、生物学、医学遗传学、营养学和病毒学。

  6. CMMLU (medical disciplines):中文医学相关学科的评估基准,包括解剖学、临床知识、中医等,用于测试中文医学推理能力。

image-20241116223440488

数学和代码

  1. GSM8K:一个包含 8,000 道数学问题的评估基准,使用 4-shot 测试数学推理能力。

  2. MATH:一个包含 12,500 道难度更高的数学问题的数据集,使用 4-shot 测试模型的数学解题能力。

  3. HumanEval:一个零样本编程任务评估基准,包括语言理解、推理、算法和简单数学问题,用于测试模型的编程正确性和问题解决能力。

  4. MBPP (The ManyBabies Python Benchmark):包含 974 个 Python 短函数和程序文本描述,以及用于验证功能正确性的测试用例,采用 3-shot 测试模型的编程能力。

image-20241116223613264

多语言评估

  1. Flores-101:一个多语言评估基准,覆盖全球 101 种语言,数据来源包括新闻、旅行指南和书籍等领域,用于测试多语言能力。

  2. zh-en:中文到英文翻译任务,用于评估跨语言的翻译能力。

  3. zh-fr:中文到法文翻译任务,用于测试模型在中文和法文之间的翻译表现。

  4. zh-es:中文到西班牙文翻译任务,测试模型在中文和西班牙文之间的翻译能力。

  5. zh-ar:中文到阿拉伯文翻译任务,用于测试模型处理中文与阿拉伯文的翻译水平。

  6. zh-ru:中文到俄文翻译任务,用于评估中文和俄文之间的翻译能力。

  7. zh-ja:中文到日文翻译任务,测试模型在中文和日文之间的翻译准确性。

  8. zh-de:中文到德文翻译任务,用于评估模型在中文和德文之间的翻译表现。

image-20241116223655460

代码复现

代码已经全部开源,仓库地址:https://github.com/ResDream/BaichuanEval

选取了三个数据集进行测评,C-Eval和MMLU两个用于评估整体能力的数据集、gsm8k一个用于评估数学能力的数据集。

使用Transformers库和MindNLP分别进行实验的复现。

模型选择Baichuan-7B和Beichuan2-7B-Base

C-Eval

原论文结果为:

ModelC-Eval
Baichuan-7B42.80%
Beichuan2-7B-Base54.00%

复现结果为:

ModelC-Eval
Baichuan-7B (Transformers)43.25%
Baichuan-7B (MindNLP)43.34%
Baichuan2-7B-Base (Transformers)57.65%

代码思路(MindNLP为例,使用Transformers的代码请去代码仓库里找):

对于选择题任务,我们使用模型的logits(最后输出的概率分布)预测答案。提取A, B, C, D四个选项的logits值。使用softmax计算每个选项的概率,并选择概率最大的选项作为预测答案。另外我们使用5-shot设置,在给出选择题问题之前给出5个样例。

import argparse
import json
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from tqdm import tqdm
import numpy as np
import mindspore
from datasets import load_dataset
from mindnlp.transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase, BaiChuanForCausalLM,
)


def parse_argument():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="baichuan-inc/Baichuan-7B",  # 修改此处
        help="model name or path"
    )
    parser.add_argument(
        "--shot", type=int, default=5, help="number of shot for few-shot learning"
    )
    parser.add_argument(
        "--split", type=str, default="val", help="split of dataset to evaluate"
    )
    parser.add_argument(
        "--output_dir", type=str, default="results", help="output directory"
    )
    return parser.parse_args()


class CEval:
    DATA_PATH = "ceval/ceval-exam"
    TASK2DESC = {
        "high_school_physics": "高中物理",
        "fire_engineer": "注册消防工程师",
        "computer_network": "计算机网络",
        "advanced_mathematics": "高等数学",
        "logic": "逻辑学",
        "middle_school_physics": "初中物理",
        "clinical_medicine": "临床医学",
        "probability_and_statistics": "概率统计",
        "ideological_and_moral_cultivation": "思想道德修养与法律基础",
        "operating_system": "操作系统",
        "middle_school_mathematics": "初中数学",
        "chinese_language_and_literature": "中国语言文学",
        "electrical_engineer": "注册电气工程师",
        "business_administration": "工商管理",
        "high_school_geography": "高中地理",
        "modern_chinese_history": "近代史纲要",
        "legal_professional": "法律职业资格",
        "middle_school_geography": "初中地理",
        "middle_school_chemistry": "初中化学",
        "high_school_biology": "高中生物",
        "high_school_chemistry": "高中化学",
        "physician": "医师资格",
        "high_school_chinese": "高中语文",
        "tax_accountant": "税务师",
        "high_school_history": "高中历史",
        "mao_zedong_thought": "毛泽东思想和中国特色社会主义理论概论",
        "high_school_mathematics": "高中数学",
        "professional_tour_guide": "导游资格",
        "veterinary_medicine": "兽医学",
        "environmental_impact_assessment_engineer": "环境影响评价工程师",
        "basic_medicine": "基础医学",
        "education_science": "教育学",
        "urban_and_rural_planner": "注册城乡规划师",
        "middle_school_biology": "初中生物",
        "plant_protection": "植物保护",
        "middle_school_history": "初中历史",
        "high_school_politics": "高中政治",
        "metrology_engineer": "注册计量师",
        "art_studies": "艺术  学",
        "college_economics": "大学经济学",
        "college_chemistry": "大学化学",
        "law": "法学",
        "sports_science": "体育学",
        "civil_servant": "公务员",
        "college_programming": "大学编程",
        "middle_school_politics": "初中政治",
        "teacher_qualification": "教师资格",
        "computer_architecture": "计算机组成",
        "college_physics": "大学物理",
        "discrete_mathematics": "离散数学",
        "marxism": "马克思主义基本原理",
        "accountant": "注册会计师",
    }

    def __init__(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizerBase,
            output_dir: str,
    ):
        self.model = model
        self.tokenizer = tokenizer
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        self.output_dir = output_dir

    def run(self, shot: int, split: str):
        results, accs = {}, {}

        # run all task
        for task_name in self.TASK2DESC:
            print("=" * 100)
            print(f"run task: {task_name}")
            result, acc = self.run_single_task(task_name, shot, split)
            results[task_name] = result
            accs[task_name] = acc
            result_path = os.path.join(self.output_dir, f"{task_name}.json")
            with open(result_path, "w") as f:
                json.dump(result, f, indent=2)
            print(f"save result to {result_path}")

        # results
        acc_path = os.path.join(self.output_dir, "acc.json")
        with open(acc_path, "w") as f:
            json.dump(accs, f, indent=2)
        average_acc = sum(accs.values()) / len(accs)
        print(f"average acc: {average_acc}")

    def run_single_task(self, task_name: str, shot: int, split: str):
        dataset = load_dataset(
            self.DATA_PATH,
            task_name,
            trust_remote_code=True,
            force_download=True
        )
        results = []
        acc = 0
        for data in tqdm(dataset[split]):
            prompt = f"以下是中国关于{self.TASK2DESC[task_name]}考试的单项选择题,请选出其中的正确答案。\n"
            if shot != 0:
                shuffled = dataset["dev"].shuffle()
                for i in range(min(shot, len(shuffled))):
                    prompt += "\n" + self.build_example(shuffled[i], with_answer=True)
            prompt += "\n" + self.build_example(data, with_answer=False)

            input_ids = self.tokenizer.encode(prompt, return_tensors="ms")

            # 获取logits
            logits = self.model(input_ids=input_ids).logits[:, -1].flatten()

            # 获取候选项的logits
            candidate_logits = mindspore.tensor(
                [logits[self.tokenizer(label).input_ids[-1]] for label in ["A", "B", "C", "D"]])

            pred_idx = mindspore.ops.softmax(candidate_logits, axis=0).argmax().item()
            answer = {0: "A", 1: "B", 2: "C", 3: "D"}[pred_idx]

            results.append({
                "prompt": prompt,
                "correct": answer == data["answer"].strip().upper(),
                "answer": answer,
            })
            acc += answer == data["answer"].strip().upper()

        acc /= len(dataset[split])
        return results, acc

    def build_example(self, data, with_answer: bool = True):
        question = data["question"]
        choice = "\n".join(
            [
                "A. " + data["A"],
                "B. " + data["B"],
                "C. " + data["C"],
                "D. " + data["D"],
            ]
        )
        answer = data["answer"].strip().upper() if with_answer else ""
        return f"{question}\n{choice}\n答案:{answer}"


def main():
    args = parse_argument()

    model = BaiChuanForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        ms_dtype=mindspore.bfloat16,
        mirror="huggingface",
        size="7b"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        use_fast=True,
        add_bos_token=False,
        add_eos_token=False,
        padding_side="left",
    )
    ceval = CEval(model, tokenizer, args.output_dir)
    ceval.run(args.shot, args.split)


if __name__ == "__main__":
    main()

MMLU

原论文结果为:

ModelMMLU
Baichuan-7B42.30%
Beichuan2-7B-Base54.16%

复现结果为:

ModelMMLU
Baichuan-7B (Transformers)35.17%
Baichuan-7B (MindNLP)36.58%
Baichuan2-7B-Base (Transformers)52.15%

代码思路和C-Eval一致,同样使用模型的logits(最后输出的概率分布)预测答案,使用5-shot设置,在给出选择题问题之前给出5个样例。

import os
import argparse
import mindspore
import numpy as np
import json
import os
from datetime import datetime
from pathlib import Path
from categories import subcategories, categories
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer,BaiChuanForCausalLM
from datasets import load_dataset




# 常量定义保持不变
choices = ["A", "B", "C", "D"]
SUBJECTS = [
    'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
    'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
    'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
    'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
    'global_facts', 'high_school_biology', 'high_school_chemistry',
    'high_school_computer_science', 'high_school_european_history', 'high_school_geography',
    'high_school_government_and_politics', 'high_school_macroeconomics',
    'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics',
    'high_school_psychology', 'high_school_statistics', 'high_school_us_history',
    'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law',
    'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing',
    'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition',
    'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
    'professional_medicine', 'professional_psychology', 'public_relations',
    'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions'
]


def format_subject(subject):
    l = subject.split("_")
    return " ".join(l)


def format_example(example, include_answer=True):
    """格式化单个示例"""
    prompt = example['question']
    for j, choice in enumerate(choices):
        prompt += f"\n{choice}. {example[f'choices'][j]}"
    prompt += "\nAnswer:"
    if include_answer:
        prompt += f" {example['answer']}\n\n"
    return prompt


def gen_prompt(train_examples, subject, k=-1):
    """生成提示文本"""
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = len(train_examples)
    for i in range(k):
        prompt += format_example(train_examples[i])
    return prompt


# @mindspore.jit
def eval(args, subject, model, tokenizer, dev_examples, test_examples):
    cors = []
    all_probs = []
    predictions = []  # 存储详细预测结果

    letter_to_number = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    number_to_letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}

    print(f"\n{'=' * 20} 正在评估 {subject} {'=' * 20}")

    for i, example in enumerate(test_examples):
        k = args.ntrain
        prompt_end = format_example(example, include_answer=False)
        train_prompt = gen_prompt(dev_examples, subject, k)
        prompt = train_prompt + prompt_end

        inputs = tokenizer(prompt, return_tensors="ms")

        if inputs.input_ids.shape[-1] > model.config.max_position_embeddings:
            while inputs.input_ids.shape[-1] > model.config.max_position_embeddings:
                k -= 1
                train_prompt = gen_prompt(dev_examples, subject, k)
                prompt = train_prompt + prompt_end
                inputs = tokenizer(prompt, return_tensors="ms")

        outputs = model(**inputs)
        logits = outputs.logits[0, -1]

        choice_probs = []
        for choice in choices:
            choice_id = tokenizer.encode(" " + choice, add_special_tokens=False)[0]
            choice_probs.append(logits[choice_id].item())

        probs = mindspore.ops.softmax(mindspore.tensor(choice_probs), axis=0).numpy()
        pred_letter = choices[np.argmax(choice_probs)]
        pred_number = letter_to_number[pred_letter]

        label = int(example['answer'])
        label_letter = number_to_letter[label]
        cor = pred_number == label

        cors.append(cor)
        all_probs.append(probs)

        # 存储详细预测结果
        predictions.append({
            'question': example['question'],
            'choices': example['choices'],
            'prediction': pred_letter,
            'correct_answer': label_letter,
            'probabilities': probs.tolist(),
            'correct': cor
        })

    acc = np.mean(cors)
    print(f"\n科目: {subject}")
    print(f"平均准确率: {acc:.3f}")
    print(f"总样本数: {len(test_examples)}")
    print(f"正确预测数: {sum(cors)}")
    print("=" * 50)

    return {
        'accuracy': acc,
        'total_examples': len(test_examples),
        'correct_count': sum(cors),
        'detailed_predictions': predictions,
        'raw_cors': cors,
        'probabilities': [p.tolist() for p in all_probs]
    }


def save_results(args, results_data, timestamp):
    """保存评估结果到JSON文件"""
    try:
        # 创建保存目录
        save_dir = Path(args.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # 创建包含模型名称和时间戳的文件名
        model_name = args.model.split('/')[-1]
        filename = f"mmlu_results_{model_name}_{timestamp}.json"
        save_path = save_dir / filename

        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存至: {save_path}")

    except Exception as e:
        print(f"\n保存结果时出错: {str(e)}")


def main(args):
    # 记录开始时间和时间戳
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    model = BaiChuanForCausalLM.from_pretrained(
        args.model,
        ms_dtype=mindspore.float16,
        trust_remote_code=True,
        mirror="huggingface",
        size="7b"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=True,
        mirror="huggingface"
    )

    # 初始化结果存储
    results_data = {
        "metadata": {
            "model": args.model,
            "num_train_examples": args.ntrain,
            "timestamp": timestamp,
        },
        "overall_results": {
            "total_correct": 0,
            "total_questions": 0,
            "overall_accuracy": 0
        },
        "subject_results": {},
        "subcategory_results": {
            subcat: {"correct": 0, "total": 0, "accuracy": 0}
            for subcat_lists in subcategories.values()
            for subcat in subcat_lists
        },
        "category_results": {
            cat: {"correct": 0, "total": 0, "accuracy": 0}
            for cat in categories
        }
    }

    # 评估每个科目
    for subject in SUBJECTS:
        # try:
        dataset = load_dataset("cais/mmlu", subject)

        dev_examples = list(dataset['dev'])
        if args.ntrain > len(dev_examples):
            print(f"警告: 要求{args.ntrain}个示例,但{subject}只有{len(dev_examples)}个可用")
            dev_examples = dev_examples[:len(dev_examples)]
        else:
            dev_examples = dev_examples[:args.ntrain]

        test_examples = list(dataset['test'])

        if not dev_examples or not test_examples:
            print(f"跳过{subject} - 未找到示例")
            continue

        # 获取该科目的评估结果
        subject_results = eval(args, subject, model, tokenizer, dev_examples, test_examples)
        results_data["subject_results"][subject] = subject_results

        # 更新总体统计
        results_data["overall_results"]["total_correct"] += subject_results["correct_count"]
        results_data["overall_results"]["total_questions"] += subject_results["total_examples"]

        # 更新子类别和类别统计
        if subject in subcategories:
            subcats = subcategories[subject]
            for subcat in subcats:
                results_data["subcategory_results"][subcat]["correct"] += subject_results["correct_count"]
                results_data["subcategory_results"][subcat]["total"] += subject_results["total_examples"]

                for key in categories.keys():
                    if subcat in categories[key]:
                        results_data["category_results"][key]["correct"] += subject_results["correct_count"]
                        results_data["category_results"][key]["total"] += subject_results["total_examples"]

        # except Exception as e:
        #     print(f"处理科目 {subject} 时出错: {str(e)}")
        #     continue

    # 计算最终准确率
    total_correct = results_data["overall_results"]["total_correct"]
    total_questions = results_data["overall_results"]["total_questions"]
    results_data["overall_results"]["overall_accuracy"] = total_correct / total_questions if total_questions > 0 else 0

    # 计算子类别和类别准确率
    for subcat in results_data["subcategory_results"]:
        total = results_data["subcategory_results"][subcat]["total"]
        if total > 0:
            results_data["subcategory_results"][subcat]["accuracy"] = \
                results_data["subcategory_results"][subcat]["correct"] / total

    for cat in results_data["category_results"]:
        total = results_data["category_results"][cat]["total"]
        if total > 0:
            results_data["category_results"][cat]["accuracy"] = \
                results_data["category_results"][cat]["correct"] / total

    # 打印最终结果
    print("\n" + "=" * 20 + " 最终结果 " + "=" * 20)
    print(f"总体准确率: {results_data['overall_results']['overall_accuracy']:.3f}")
    print(f"总正确数: {total_correct}")
    print(f"总题目数: {total_questions}")

    print("\n子类别结果:")
    for subcat, results in results_data["subcategory_results"].items():
        if results["total"] > 0:
            print(f"{subcat}: {results['accuracy']:.3f}")

    print("\n类别结果:")
    for cat, results in results_data["category_results"].items():
        if results["total"] > 0:
            print(f"{cat}: {results['accuracy']:.3f}")

    # 保存结果
    save_results(args, results_data, timestamp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="baichuan-inc/Baichuan-7B"
    )
    args = parser.parse_args()
    main(args)


gsm8k

原论文结果为:

Modelgsm8k
Baichuan-7B9.17%
Beichuan2-7B-Base24.49%

复现结果为:

Modelgsm8k
Baichuan-7B (Transformers)11.43%
Baichuan-7B (MindNLP)11.56%
Baichuan2-7B-Base (Transformers)25.37%

代码思路:

使用4-shot设置,在给出选择题问题之前给出4个样例。对于gsm8k数据集,由于在提示中,我们限制模型回复 The Answer is [\answer] ,所以我们提取所有生成文本中的最后一个数字作为答案。

import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import re
import mindspore
import argparse
import jsonlines
import numpy as np
import datasets
from datasets import load_from_disk, load_dataset
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer, BaiChuanForCausalLM

ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"


def check_prompt_file(prompt_path):
    if not os.path.exists(prompt_path):
        raise FileNotFoundError(f"Prompt file not found at {prompt_path}")
    return open(prompt_path).read()


def doc_to_text(doc, prompt):
    return (
            prompt
            + "\nQuestion: "
            + doc["question"]
            + "\nLet's think step by step\n"
    )


def decode(tokens_list, tokenizer, raw_text_len):
    sents = []
    for tokens in tokens_list:
        tokens = tokens.cpu().numpy().tolist()
        sent = tokenizer.tokenizer.decode(tokens[raw_text_len:])
        sent = sent.split("<|endoftext|>")[0]
        sent = sent.split("\n\n\n")[0]
        sent = sent.split("\n\n")[0]
        sent = sent.split("Question:")[0]
        sents.append(sent)
    return sents


def load_model_and_tokenizer(checkpoint_path):
    print("Loading tokenizer ...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint_path, trust_remote_code=True
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load tokenizer: {str(e)}")

    print("Loading model ...")
    try:
        model = BaiChuanForCausalLM.from_pretrained(
            checkpoint_path,
            ms_dtype=mindspore.float16,
            trust_remote_code=True
        ).eval()
        # Move model to GPU
        model.to("cuda:0")
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}")

    return model, tokenizer


def generate_sample(model, tokenizer, input_txt):
    input_ids = tokenizer.encode(input_txt)
    raw_text_len = len(input_ids)
    context_enc = mindspore.tensor([input_ids])
    # print(f"Input text: {input_txt}\n")

    try:
        # 直接在generate()中设置参数
        outputs = model.generate(
            context_enc,
            max_length=2048,
            do_sample=False,
        )
        output_text = decode(outputs, tokenizer, raw_text_len)[0]
        # print(f"\nOutput text: {output_text}\n")
        return output_text
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        return ""


def extract_answer(completion):
    # First try to extract answer in the standard format
    match = ANS_RE.search(completion)
    if match:
        try:
            match_str = match.group(1).strip()
            match_str = match_str.replace(",", "")
            return eval(match_str)
        except:
            pass

    # Fall back to looking for the last number in the text
    try:
        last_number = re.findall(r"\d+", completion)[-1]
        return eval(last_number)
    except:
        return INVALID_ANS


def is_correct(completion, answer):
    try:
        gold = extract_answer(answer)
        if gold == INVALID_ANS:
            print("Warning: No ground truth answer found in the document.")
            return False
        return extract_answer(completion) == gold
    except Exception as e:
        print(f"Error comparing answers: {str(e)}")
        return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        help="Checkpoint path",
        default="baichuan-inc/Baichuan-7B",
    )
    parser.add_argument("-f", "--sample-input-file", type=str, default=None)
    parser.add_argument(
        "-o", "--sample-output-file", type=str, default="gsm8k_res.jsonl"
    )
    parser.add_argument(
        "-p", "--prompt-file", type=str, default="gsm8k_prompt.txt"
    )

    args = parser.parse_args()

    # Create results directory if it doesn't exist
    if not os.path.exists("results"):
        os.makedirs("results")

    # Update output file path to include results directory
    args.sample_output_file = os.path.join("results", args.sample_output_file)

    # Load and verify prompt file
    try:
        fewshot_prompt = check_prompt_file(args.prompt_file)
    except FileNotFoundError as e:
        print(f"Error: {str(e)}")
        print("Please ensure the prompt file exists at the specified location")
        exit(1)

    # Load dataset
    try:
        if args.sample_input_file is not None:
            dataset = load_from_disk(args.sample_input_file)
        else:
            config = datasets.DownloadConfig(resume_download=True, max_retries=100)
            dataset = load_dataset("gsm8k", "main", download_config=config)
        test = dataset["test"]
    except Exception as e:
        print(f"Failed to load dataset: {str(e)}")
        exit(1)

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(args.checkpoint_path)

    # Process samples
    f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))
    tot_length = test.num_rows
    acc_res = []

    try:
        for i, doc in enumerate(test):
            print(f"\nProcessing sample {i + 1}/{tot_length}")
            context = doc_to_text(doc, fewshot_prompt)
            completion = generate_sample(model, tokenizer, context)
            answer = doc["answer"]
            print(answer)
            acc = is_correct(completion, answer)
            doc["completion"] = completion
            doc["acc"] = acc
            f_output.write(doc)
            acc_res.append(acc)

            if (i + 1) % 10 == 0:
                print(f"Current accuracy: {np.mean(acc_res):.4f}")

    except Exception as e:
        print(f"Error during evaluation: {str(e)}")
    finally:
        f_output.close()
        if acc_res:
            print(f"\nFinal accuracy: {np.mean(acc_res):.4f}")

标签:Baichuan,prompt,模型,results,详解,复现,print,Baichuan2,model
From: https://blog.csdn.net/qq_51957239/article/details/143942202

相关文章

  • 内存函数详解
    1.memcpy使⽤和模拟实现2.memmove使⽤和模拟实现3.memset函数的使⽤4.memcmp函数的使⽤一.memcpy的使用与模拟实现1.定义:注意:1.memcpy返回的是目的地的指针      2.使用时要包含头文件string.h      3.num指的是拷贝的个数(单位为字节)2.......
  • Altenergy电力系统控制软件 status_zigbee SQL注入漏洞复现(CVE-2024-11305)
     0x01阅读须知        技术文章仅供参考,此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失,均由使用......
  • PlantUML+vscode使用详解
    目录PlantUML使用1.Uml图1.1Uml类图1.2类图关系2.PlantUML功能概述2.1PlantUML核心特性2.2PlantUML的优势3.文本定义语言生成图表PlantUML学习指南4.自动转换c#源码工具4.1CSharptoPlantUML(VisualStudioCode扩展)4.2PlantUmlClassDiagramGeneratorNuget地址安装使用示......
  • 最全亚马逊批量实时采集商品链接方法(图文详解),还能看市场集中度!
    第一步:准备关键词、类目ID、店铺ID、ASIN可以通过选关键词、选类目2种方式采集关键词,自己卡条件批量下载下来。第二步:了解以词挖品、类目top1万采集、畅销榜采集、店铺挖品、僵尸链接、asin采集等等以词挖品方法一以词挖品方法二以词挖品方法三以词挖品方法四......
  • HarmonyOS Next加解密算法中的参数与模式详解
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)中加解密算法参数与模式的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。一、加解密参数......
  • 详解线程的三大特性:原子性、可见性和有序性
    在多线程编程中,理解线程的原子性、可见性和有序性是构建正确并发程序的基础。以下是它们的详细解释:1.原子性(Atomicity)定义原子性指的是操作不可被中断,要么全部执行完成,要么完全不执行。特性原子性操作在执行时不会被其他线程干扰。如果多个线程同时访问共享资......
  • 嵌入式硬件电子电路设计(七)稳压二极管-齐纳二极管-齐纳击穿全面详解
    引言:在嵌入式硬件电子电路设计中,稳压二极管(又称齐纳二极管)是一种常用的元件,主要用于电压稳定、过压保护和电路调试。齐纳二极管利用齐纳击穿效应,在反向工作状态下能够维持稳定的电压输出,因此被广泛应用于各种电源电路和信号调理电路中。理解齐纳二极管的工作原理及其在实际电路......
  • JavaScript初识及基本语法详解
    JavaScript是一种轻量级的编程语言,它可以在网页中嵌入,用来控制网页的动态效果和用户交互。JavaScript是所有现代网页浏览器都支持的脚本语言,它可以让网页变得“活”起来,实现各种复杂的功能。JavaScript的基本语法JavaScript的语法基础与Java语言类似,但它是解释型语言,不需要编......
  • Java 值传递详解
    形参&实参方法的定义可能会用到参数(有参的方法),参数在程序语言中分为:实参(实际参数,Arguments):用于传递给函数/方法的参数,必须有确定的值。形参(形式参数,Parameters):用于定义函数/方法,接收实参,不需要有确定的值。Stringhello="Hello!";//hello为实参sayHello(hello);//......
  • Java语法糖详解
    什么是语法糖?语法糖(SyntacticSugar)也称糖衣语法,是英国计算机学家Peter.J.Landin发明的一个术语,指在计算机语言中添加的某种语法,这种语法对语言的功能并没有影响,但是更方便程序员使用。简而言之,语法糖让程序更加简洁,有更高的可读性。 有意思的是,在编程领域,除了语法......