首页 > 编程语言 >LLaMA-Factory/scripts/length_cdf.py 源码解析

LLaMA-Factory/scripts/length_cdf.py 源码解析

时间:2024-07-02 20:31:07浏览次数:19  
标签:cdf args name fire py dataset python length 源码

这段代码定义了一个函数 length_cdf,用来计算和打印数据集样本长度的累积分布函数(CDF),并在脚本直接运行时通过 fire 库将该函数暴露为命令行接口。我们逐行解释这段代码:

python

复制

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_tokenizer
  • 从 llmtuner 模块中导入 get_datasetget_train_args 和 load_tokenizer 函数。

python

复制

def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
  • 定义了一个名为 length_cdf 的函数,接受以下参数:
    • model_name_or_path: 模型的名称或路径(字符串类型)。
    • dataset: 数据集的名称,默认为 "alpaca_en"(可选)。
    • dataset_dir: 数据集的目录,默认为 "data"(可选)。
    • template: 模板名称,默认为 "default"(可选)。
    • interval: 计算长度分布的区间,默认为 1000(可选)。

python

复制

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
  • 调用 get_train_args 函数获取模型参数、数据参数和训练参数。
  • 传递一个字典作为参数,其中包括:
    • stage: 训练阶段,这里是 "sft"(假设是某种微调)。
    • model_name_or_path: 模型的名称或路径。
    • dataset: 数据集名称。
    • dataset_dir: 数据集目录。
    • template: 模板名称。
    • cutoff_len: 截断长度,设置为 1,000,000
    • output_dir: 输出目录,这里是 "dummy_dir"
    • overwrite_cache: 是否覆盖缓存,设置为 True

python

复制

    tokenizer = load_tokenizer(model_args)
  • 调用 load_tokenizer 函数加载分词器,传入 model_args

python

复制

    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
  • 调用 get_dataset 函数获取训练数据集,传入 tokenizermodel_argsdata_args 和 training_args,指定 stage 为 "sft"

python

复制

    total_num = len(trainset)
  • 计算训练数据集的总样本数,存储在 total_num 中。

python

复制

    length_dict = defaultdict(int)
  • 创建一个默认值为 0 的字典 length_dict,用来统计样本长度的分布。   
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1
  • 遍历训练数据集中的每个样本(input_ids)。
  • 计算每个样本长度所在的区间(len(sample) // interval * interval),并在 length_dict 中进行计数。

python

复制

    length_tuples = list(length_dict.items())
    length_tuples.sort()
  • 将 length_dict 转换为一个包含长度区间和计数的元组列表 length_tuples
  • 对列表进行排序。

python

复制

    count_accu, prob_accu = 0, 0
  • 初始化累积计数 count_accu 和累积概率 prob_accu

python

复制

    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
  • 遍历 length_tuples 中的每个长度区间和计数。
  • 更新累积计数 count_accu 和累积概率 prob_accu

python

复制

        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
  • 打印当前长度区间的累积计数和累积概率。输出格式为:count_accu (prob_accu%) samples have length < length + interval.

python

复制

if __name__ == "__main__":
    fire.Fire(length_cdf)
  • 检查脚本是否作为主程序运行。
  • 如果是,则使用 fire.Fire 将 length_cdf 函数暴露为命令行接口,使得可以通过命令行参数直接调用该函数。

总结

这段代码的主要功能是计算数据集样本长度的累积分布函数(CDF),并打印每个长度区间的累积计数和累积概率。具体步骤包括:

  1. 从命令行参数获取模型和数据集的相关参数。
  2. 加载分词器和数据集。
  3. 统计样本长度的分布。
  4. 计算并打印累积分布。

示例用法

假设你将这段代码保存为 length_cdf.py,可以通过以下命令运行并查看结果:

sh

复制

python length_cdf.py --model_name_or_path "path/to/model" --dataset "custom_dataset" --dataset_dir "path/to/data"

这将根据提供的模型和数据集路径,计算并打印数据集中样本长度的累积分布情况。

为什么要用fire库,有什么好处,用和不用有什么区别

GPT-4o

Poe

使用 fire 库的主要目的是简化命令行接口的创建,使得 Python 脚本可以更容易地通过命令行参数进行调用和控制。下面我们详细解释一下 fire 库的好处,以及使用和不使用 fire 库的区别。

fire 库的好处

  1. 自动生成命令行接口

    • fire 可以自动将 Python 对象(函数、类、模块等)转换为命令行接口。开发者无需手动解析命令行参数,fire 会自动处理这些工作。
  2. 简化代码

    • 使用 fire 只需一行代码就能将函数或类暴露为命令行接口,极大地简化了代码量。
  3. 易用性

    • fire 提供了一致且直观的方式来调用 Python 脚本,降低了使用命令行参数的复杂性。
  4. 灵活性

    • fire 支持多种调用方式,包括函数、类和模块,使得开发者可以根据需要灵活选择如何组织代码。

使用和不使用 fire 的区别

使用 fire 库

使用 fire 库的代码示例如下:

python

复制

import fire

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    fire.Fire(greet)
  • 当你运行这个脚本时,可以通过命令行参数直接调用 greet 函数并传递参数:

sh

复制

python script.py --name John
  • 输出:

    复制

    Hello, John!
    
不使用 fire 库

如果不使用 fire 库,你需要手动解析命令行参数,例如使用 argparse 库:

import argparse

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Greet someone.")
    parser.add_argument("--name", type=str, default="World", help="Name of the person to greet.")
    args = parser.parse_args()
    greet(args.name)
  • 当你运行这个脚本时,同样可以通过命令行参数传递参数:

sh

复制

python script.py --name John
  • 输出:

    复制

    Hello, John!
    

对比分析

  • 代码简洁性

    • 使用 fire 库,代码更加简洁,只需一行代码即可实现命令行接口。
    • 不使用 fire 库,需要更多的代码来设置和解析命令行参数。
  • 功能和灵活性

    • fire 提供了自动化和简化的功能,适合快速开发和测试。
    • argparse 等库提供了更多的自定义选项和复杂功能,适合复杂的命令行接口需求。

结论

  • 如果你需要快速创建一个简单的命令行接口,fire 是一个很好的选择,因为它可以极大地简化开发过程。
  • 如果你需要更多的控制和自定义选项,比如处理复杂的参数解析逻辑,argparse 或其他专门的命令行解析库可能更适合。

综上所述,选择是否使用 fire 取决于你的具体需求和项目的复杂性。

标签:cdf,args,name,fire,py,dataset,python,length,源码
From: https://blog.csdn.net/sinat_37574187/article/details/140135063

相关文章

  • Advanced Data Analytics Using Python_ With Machine Learning, Deep Learning and N
    本书提供了使用Python进行高级数据分析的方法,涵盖了机器学习、深度学习和自然语言处理的应用实例。书中详细讲解了如何在不同的数据库环境中进行数据提取、转换和加载(ETL),并探讨了监督学习、无监督学习、深度学习、时间序列分析以及大规模数据分析的相关内容。目录简介为......
  • python+flask计算机毕业设计的家庭收支记账管理系统(程序+开题+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景随着社会的快速发展和人们生活水平的提高,家庭财务管理变得越来越重要。然而,传统的家庭记账方式往往存在效率低下、易出错、不易统计等问题......
  • python+flask计算机毕业设计的建材店库存管理系统(程序+开题+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景随着经济的迅速发展和建筑行业的蓬勃兴起,建材市场的竞争日益激烈。建材店作为建筑行业的重要供应链环节,其库存管理水平直接影响着店铺的运......
  • python+flask计算机毕业设计的礼服租赁管理系统(程序+开题+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景随着人们对个性化、时尚化需求的日益增长,礼服租赁市场逐渐崭露头角。然而,传统的礼服租赁管理方式往往依赖于纸质记录和人工操作,这种方式不......
  • python: create Envircomnet in vscode 创建虚拟机
    先配置python开发环境 1.在搜索栏输入“>"或是用快后键ctrl+shift+P键就会显示”>",再输入"python:"   选择已经安装好的python的版本至当前项目中   2terminal下输入相关命令行操作,选进入虚拟机的的文件夹下cd.venv/scripts再进行激活虚拟机。输入ac......
  • 基于Python实现的简单电影推荐
    ⚠申明:未经许可,禁止以任何形式转载,若要引用,请标注链接地址。全文共计5071字,阅读大概需要10分钟......
  • python系列&deep_study系列:python如何将语音转文字
    python如何将语音转文字python如何将语音转文字在本文中,我们将探讨解决此问题的三种不同方法。方法1:使用SpeechRecognition库方法2:使用GoogleCloudSpeech-to-TextAPI方法3:使用PyAudio库python如何将语音转文字如果在python中将语音转换成文本?在本文......
  • python系列&deep_study系列:Whisper——部署fast-whisper中文语音识别模型
    Whisper——部署fast-whisper中文语音识别模型Whisper——部署fast-whisper中文语音识别模型环境配置准备tiny模型模型转换代码Whisper——部署fast-whisper中文语音识别模型whisper:https://github.com/openai/whisper/tree/main参考文章:WhisperOpenAI开源语音......
  • Python解释器安装
    1、首先,我们进入官网链接如下:www.python.org我们选择页面上方下载按钮这里我们选择3.11.1版本然后选择推荐版本等待下载完毕2、下载完成之后选择自定义安装选择位置和功能这里直接下一步这里选择向所有用户安装python3.11和向环境变量添加python,然后再自定义安装......
  • 【课程设计】基于python的一款简单的计算器
    我们是大二本科生团队,主力两人耗时3天完成了这款计算器的制作。希望大家给我们多多引流!!!!!!欢迎各位优秀的高考学子报考长安大学,报考长安大学电子信息工程专业。欢迎有志于就读信息与通信工程和计算机科学与技术的准研究生报考长安大学!注意:文件我已经打包好了!长安大学直属国家......