首页 > 编程语言 >llamafactory/model/loader.py [源码解析]

llamafactory/model/loader.py [源码解析]

时间:2024-07-19 10:54:54浏览次数:23  
标签:args tokenizer llamafactory python py tokens 源码 kwargs model

class TokenizerModule(TypedDict):
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]


def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
    r"""
    Gets arguments to load config/tokenizer/model.

    Note: including inplace operation of model_args.
    """
    skip_check_imports()
    model_args.model_name_or_path = try_download_model_from_ms(model_args)
    return {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token,
    }

我会逐行解释这段代码:

python

复制

class TokenizerModule(TypedDict):
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]

这定义了一个名为TokenizerModuleTypedDict。它包含两个字段:

  • tokenizer: 类型为PreTrainedTokenizer
  • processor: 可选类型,为ProcessorMixin
    TypedDict用于定义包含特定类型键的字典。

python

复制

def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

这定义了一个名为_get_init_kwargs的函数,它接受一个ModelArguments类型的参数,并返回一个字典(键为字符串,值为任意类型)。

python

复制

    r"""
    Gets arguments to load config/tokenizer/model.

    Note: including inplace operation of model_args.
    """

这是函数的文档字符串,描述了函数的功能和注意事项。

python

复制

    skip_check_imports()

调用skip_check_imports()函数,可能用于跳过某些导入检查。

python

复制

    model_args.model_name_or_path = try_download_model_from_ms(model_args)

尝试从某个源(可能是Microsoft)下载模型,并更新model_argsmodel_name_or_path属性。

python

复制

    return {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token,
    }

函数返回一个字典,包含以下键值对:

  • "trust_remote_code": 设置为True
  • "cache_dir": 使用model_argscache_dir属性
  • "revision": 使用model_argsmodel_revision属性
  • "token": 使用model_argshf_hub_token属性

这个函数似乎是为了准备加载模型配置、分词器或模型本身所需的参数。它包含了一些安全设置(如trust_remote_code)和版本控制相关的参数(如revision)。同时,它还处理了缓存目录和访问令牌,这些可能用于从Hugging Face模型中心下载模型。

注意,这个函数对model_args进行了原地修改,这在函数的文档字符串中有提到。

def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    r"""
    Loads pretrained tokenizer.

    Note: including inplace operation of model_args.
    """
    init_kwargs = _get_init_kwargs(model_args)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=model_args.use_fast_tokenizer,
            split_special_tokens=model_args.split_special_tokens,
            padding_side="right",
            **init_kwargs,
        )
    except ValueError:  # try the fast one
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=True,
            padding_side="right",
            **init_kwargs,
        )

    if model_args.new_special_tokens is not None:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=model_args.new_special_tokens),
            replace_additional_special_tokens=False,
        )
        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning("New tokens have been added, changed `resize_vocab` to True.")

    patch_tokenizer(tokenizer)

我会继续逐行解释这段代码:

python

复制

def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":

这定义了一个名为load_tokenizer的函数,它接受一个ModelArguments类型的参数,并返回一个TokenizerModule类型的对象。

python

复制

    r"""
    Loads pretrained tokenizer.

    Note: including inplace operation of model_args.
    """

这是函数的文档字符串,描述了函数的功能和注意事项。

python

复制

    init_kwargs = _get_init_kwargs(model_args)

调用之前定义的_get_init_kwargs函数,获取初始化参数。

python

复制

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=model_args.use_fast_tokenizer,
            split_special_tokens=model_args.split_special_tokens,
            padding_side="right",
            **init_kwargs,
        )
    except ValueError:  # try the fast one
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=True,
            padding_side="right",
            **init_kwargs,
        )

这段代码尝试加载预训练的tokenizer。如果第一次尝试失败(抛出ValueError),它会再次尝试,但强制使用fast tokenizer。

python

复制

    if model_args.new_special_tokens is not None:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=model_args.new_special_tokens),
            replace_additional_special_tokens=False,
        )
        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning("New tokens have been added, changed `resize_vocab` to True.")

如果指定了新的特殊token,这段代码会将它们添加到tokenizer中。如果成功添加了新token并且resize_vocab原本为False,它会将resize_vocab设置为True并记录一个警告。

python

复制

    patch_tokenizer(tokenizer)

这行调用patch_tokenizer函数,可能用于对tokenizer进行一些自定义修改。

这个函数的主要目的是加载和配置一个预训练的tokenizer。它处理了可能的错误,添加了新的特殊token(如果有的话),并可能对tokenizer进行了一些自定义修改。函数的设计考虑到了灵活性和错误处理,使其能够适应不同的模型和配置。

if model_args.visual_inputs:
        try:
            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
            setattr(processor, "tokenizer", tokenizer)
        except Exception:
            raise ValueError(
                "This multimodal LLM is not supported.\n"
                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
            )
    else:
        processor = None

    return {"tokenizer": tokenizer, "processor": processor}


def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
    r"""
    Loads model config.
    """
    init_kwargs = _get_init_kwargs(model_args)
    return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)

我会继续逐行解释这段代码:

python

复制

if model_args.visual_inputs:

这行检查model_args是否包含visual_inputs属性,并且该属性为真。这可能表示模型需要处理视觉输入。

python

复制

    try:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
        setattr(processor, "tokenizer", tokenizer)

如果模型支持视觉输入,代码尝试加载一个处理器(processor)。然后,它将之前创建的tokenizer作为属性添加到处理器中。

python

复制

    except Exception:
        raise ValueError(
            "This multimodal LLM is not supported.\n"
            "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
            "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
        )

如果加载处理器失败,代码抛出一个ValueError,提示用户这个多模态LLM不被支持,并建议下载LLaVA-1.5或Yi-VL模型。这里提到了BUAADreamer,与您提供的URL内容相符。

python

复制

else:
    processor = None

如果模型不需要视觉输入,处理器被设置为None。

python

复制

return {"tokenizer": tokenizer, "processor": processor}

函数返回一个字典,包含tokenizer和processor(如果有的话)。

python

复制

def load_config(model_args: "ModelArguments") -> "PretrainedConfig":

这定义了一个新函数load_config,它接受ModelArguments类型的参数,并返回PretrainedConfig类型的对象。

python

复制

    r"""
    Loads model config.
    """

这是函数的文档字符串,描述了函数的功能。

python

复制

    init_kwargs = _get_init_kwargs(model_args)

调用之前定义的_get_init_kwargs函数,获取初始化参数。

python

复制

    return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)

使用AutoConfig.from_pretrained加载模型配置并返回。这使用了模型的名称或路径,以及之前准备的初始化参数。

总的来说,这段代码主要处理了多模态模型的加载过程,特别是处理视觉输入的情况。它还定义了一个用于加载模型配置的辅助函数。代码中提到的BUAADreamer确实与您提供的URL内容相符,显示了代码与特定模型仓库的关联。

标签:args,tokenizer,llamafactory,python,py,tokens,源码,kwargs,model
From: https://blog.csdn.net/sinat_37574187/article/details/140532188

相关文章

  • CentOS6.5yum安装Python3⛱️
    CentOS6.5yum安装Python3⛱️ 1.在CentOS6.5上通过Webtatic仓库安装Python3[root@localhost~]#rpm-Uvhhttps://mirror.webtatic.com/yum/el6/latest.rpm[root@localhost~]#yum-yinstallpython34upython34u-devel2.安装pip[root@localhost~]#yum-yi......
  • 使用Python爬虫下载视频源码
    一、引言(●ˇ∀ˇ●)在当今互联网时代,视频内容已成为人们获取信息和娱乐的重要方式。有时,我们可能希望将这些视频下载到本地,以便在没有网络的情况下观看。本文将介绍如何使用Python编写一个简单的爬虫。二、环境准备......
  • Pytorch模型文件`*.pt`与`*.pth` 的保存与加载
    1.*.pt文件.pt文件保存的是模型的全部,在加载时可以直接赋值给新变量model=torch.load("filename.pt")。具体操作:(1).模型的保存torch.save(model,"Path/filename.pt")(2).模型的加载model=torch.load("filename.pt")注意:torch.load()的参数使用字符串参数。2..p......
  • 【蓝牙】Android 13 蓝牙源码分析
    Android13在蓝牙模块中进行了多项改进和优化。本文将详细分析其核心组件及其工作原理,包括BluetoothManagerService、AdapterService、AdapterProperties、蓝牙连接管理和JNI接口。1.BluetoothManagerServiceBluetoothManagerService是蓝牙管理的核心类,负责启动和停止蓝......
  • Python - Conda - 对比 conda 和 pip
    之前已经写过一篇和工具相关的文章:《工具篇:makeasparrowcmakebuildsystem》,本文继续这个话题,大家可能都用过conda和pip,但是对于他们的区别和关系,可能大家不一定很清楚,本文来尝试做一些总结。一、conda1.1简介conda是一个通用的包管理器,意思是什么语言的包都可以用它进行管......
  • 计算机毕业设计Python+Tensorflow小说推荐系统 K-means聚类推荐算法 深度学习 Kears
    2、基于物品协同过滤推荐算法2.1、基于⽤户的协同过滤算法(UserCF)该算法利⽤⽤户之间的相似性来推荐⽤户感兴趣的信息,个⼈通过合作的机制给予信息相当程度的回应(如评分)并记录下来以达到过滤的⽬的进⽽帮助别⼈筛选信息,回应不⼀定局限于特别感兴趣的,特别不感兴趣信息的纪录也相......
  • 计算机毕业设计PySpark+Django高考志愿填报推荐系统 高考预测 高考大数据分析 Hadoop
    摘要本文旨在设计与实现一个基于Spark的高考志愿填报推荐系统,旨在帮助高考生根据自身成绩和兴趣,精准推荐合适的大学和专业。系统采用大数据处理框架Spark,结合机器学习算法,实现了对高考数据的深度挖掘和分析,为考生提供科学、有效的志愿填报建议。系统捕捉考生个人特征、......
  • Python 文件操作与管理:Open函数、Json与Pickle、Os模块
    1.open函数的使用Python中的open()函数是处理文件的标准方法。它允许你打开一个文件,并对其进行读取、写入或追加操作open(file,mode,encoding)函数的格式:file:文件路径mode:打开方式(读:r写:w读完之后光标停留在最后读取的位置......
  • Python数据获取(网页视频、音频版)
    爬取数据,上一章有介绍,不懂流言私信或者评论交流即可,在Python中编写爬虫通常涉及以下几个步骤:发送HTTP请求:使用requests库向目标网站发送请求。解析网页内容:使用BeautifulSoup从HTML中解析出需要的数据。下载视频文件:使用requests下载视频文件。保存到本地:将下载的视频文件......
  • 0基础学python-17:文件读写
    目录前言文件读写三步走:        打开文件-->读写文件-->关闭文件一、打开文件1.文件位置绝对位置:相对位置:2.open()方法二、读写文件1.读取文件2.写入文件三、关闭文件1.close()2.with语句前言        读写文件是最常见的IO操作。Python内置......