首页 > 编程语言 >用断点调试阅读peft源码:prefix tuning

用断点调试阅读peft源码:prefix tuning

时间:2023-08-07 22:34:34浏览次数:47  
标签:... tuning 断点 prefix 源码 model peft config hidden

今天我们阅读peft源码,主要是为了弄清楚prefix tuning的工作原理和代码细节。

模型定义部分

peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)

# 下载预训练模型T5,模型结构可以在debug console中输入model得到
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

主要是这一句:model = get_peft_model(model, peft_config),所以在这里设置断点。

首先跳转到:peft->mapping.py->get_peft_model函数。我逐行阅读并做出中文注释。

def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """
    model_config = getattr(model, "config", {"model_type": "custom"}) # 得到T5模型config,在debug console中输入model_config可以查看
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()  #把config中的属性序列化为 Python 字典

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    # <TaskType.SEQ_2_SEQ_LM: 'SEQ_2_SEQ_LM'>
    # dict_keys(['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'])
    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
        peft_config, PromptLearningConfig
    ):
        return PeftModel(model, peft_config, adapter_name=adapter_name)
    if isinstance(peft_config, PromptLearningConfig):
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)

我们从最后一句跳进去,来到了peft->peft_model.py->PeftModelForSeq2SeqLM(PeftModel)类,所以我猜测MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]定义了我们的模型是PeftModelForSeq2SeqLM并且被map到PeftModel,而传入的参数是model, peft_config, adapter_name=adapter_name.

PeftModelForSeq2SeqLM介绍如下:

"""
    Peft model for sequence-to-sequence language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModelForSeq2SeqLM, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "SEQ_2_SEQ_LM",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
        ```
    """

以上这个例子简单地用lora微调一个t5-base模型,很便捷!

prefix tuning

找半天没看到prefix tuning的代码,我们直接打开/root/miniconda3/envs/peft-practice/lib/python3.10/site-packages/peft/tuners/prefix_tuning.py查看,发现它改写自Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py

class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix

    Input shape: (batch-size, prefix-length) 
	prefix-length/num_virtual_tokens:20, hidden_size:768, prefix_hidden_size
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
    

标签:...,tuning,断点,prefix,源码,model,peft,config,hidden
From: https://www.cnblogs.com/tuyuge/p/17612914.html

相关文章

  • Flink源码解析(零)——源码解析系列随笔说明
    00、博主仅是数据开发及数仓开发工程师,出于提升自身对Flink系统原理掌握考虑,自愿花费精力整理源码解析系列随笔,并非专业Flink系统开发人员,在源码解析过程中出现非专业行为望见谅。希望Flink系统开发专业人员多提意见,不胜感激。01、Flink源码解析系列随笔主要基于Flink1.17.1版本......
  • JMeter源码解析之结果收集器
    一、JMeter结果收集器概述JMeter是在压力领域中最常见的性能测试工具,由于其开源的特点,受到广大测试和开发同学的青睐。但是,在实际应用过程中,JMeter存在的一些性能瓶颈也凸显出来,经常会遇到大并发下压不上去的情况。笔者通过深入分析其源码实现,找到JMeter存在的瓶颈问题及根本原因,为......
  • RTSP/Onvif视频服务器LntonNVR(源码版)视频平台无法通过Onvif控制摄像头云台的问题解决
    LntonNVR视频边缘计算网关平台是我们推出的软硬一体的视频平台,既有软件版本,又有硬件版本。LntonNVR与摄像头连接时,可以通过平台自带的Onvif探测进行设备探测、连接,还能实现对摄像头的PTZ云台控制,包括镜头转向、变焦等操作。通过Onvif控制云台是非常实用的功能,在很多用户实际项目中......
  • 国标GB28181视频平台LntonGBS(源码版)国标视频平台隐藏平台web页面不被访问的操作步骤
    LntonGBS国标视频云服务通过支持国标GB28181协议,实现了设备接入、实时监控直播、录像、语音对讲、云存储、告警、级联等功能。同时,它还支持将接入的视频流以多种格式(包括RTSP、RTMP、FLV、HLS、WebRTC)进行全终端、全平台分发,实现了无插件播放在Web浏览器、手机浏览器、微信端、PC客......
  • RTSP流媒体服务器LntonNVR(源码版)视频平台接入硬盘录像机的具体操作步骤
    LntonNVR是基于RTSP/Onvif协议接入的视频平台,可支持将前端设备的音视频进行采集、传输、处理并分发,实现视频监控直播、云端录像、云存储、检索回看、国标级联、告警等视频能力。平台兼容性高、可拓展性强、性能稳定,可应用在智慧工地、智慧园区、智慧工厂、智慧校园等场景中。对于新......
  • springboot智能3D导诊系统源码,基于规则模板的开发原理
    互联网智慧3D导诊系统源码通过智能导诊,进行自助问询及挂号服务,减轻导诊台护士压力,挂号更加方便快捷。技术架构:springboot+redis+mybatisplus+mysql+RocketMQ  智慧导诊系统开发原理导诊系统从原理上大致可分为基于规则模板和基于数据模型两类。1、基于规则推理的方法通过人工建......
  • 国标GB28181视频平台LntonGBS(源码版)国标平台出现录像无法播放并存在RTMP重复推流现象
    LntonGBS国标视频云服务通过支持国标GB28181协议,实现了设备接入、实时监控直播、录像、语音对讲、云存储、告警、级联等功能。同时,它还支持将接入的视频流以多种格式(包括RTSP、RTMP、FLV、HLS、WebRTC)进行全终端、全平台分发,实现了无插件播放在Web浏览器、手机浏览器、微信端、PC客......
  • JDK8:Lambda表达式使用介绍,Lambda表达式源码及原理分析
    文章目录一、Lambda表达式使用1、Lambda表达式介绍2、Lambda使用规范(1)Lambda基础格式3、Lambda表达式与传统方式比对(1)遍历集合(2)使用Lambda替换匿名内部类使用(3)实现Lambda实现集合排序二、Lambda表达式底层原理解析1、反编译lambda2、静态私有函数生成过程(1)查看内部类的内容3、forE......
  • StampedLock使用及源码分析:号称比读写锁还要快的锁
    文章目录一、StampedLock锁概述1、StampedLock锁简介2、ReentrantReadWriteLock回顾3、ReentrantReadWriteLock导致锁饥饿问题4、锁饥饿问题的缓解5、StampedLock与ReentrantReadWriteLock的对比6、StampedLock特点7、StampedLock的缺点二、StampedLock的使用1、StampedLock的三种......
  • 预约上门系统源码开发,改变服务行业的未来
    预约上门系统源码开发是一项复杂而有挑战性的任务,但也是实现智能化预约服务的关键一步。通过自主开发预约上门系统的源码,企业可以完全定制系统的功能、界面和安全性,从而为用户提供更高效、便捷、个性化的预约体验。本文将带你深入了解预约上门系统源码开发的基本步骤,并提供一些示例......