Today we read through the peft source code, mainly to work out how prefix tuning operates and the details of its implementation.
Model definition
```python
from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)
# Download the pretrained T5 model; its structure can be inspected by typing `model` in the debug console
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```
The key line is `model = get_peft_model(model, peft_config)`, so we set a breakpoint there.
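As a quick sanity check after get_peft_model returns, we can list which parameters are actually trainable. The snippet below is just an illustrative inspection loop; the expected parameter name and shape in the comment are my assumptions and depend on the PEFT version and base model.

```python
# Hypothetical inspection snippet: list the parameters that get_peft_model left trainable.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))
# For prefix tuning I would expect only the prompt encoder's embedding to show up, something like
# prompt_encoder.default.embedding.weight with shape (num_virtual_tokens, num_layers * 2 * hidden_size)
# -- the exact name and shape are assumptions.
```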
First we jump to peft -> mapping.py -> the get_peft_model function. I read it line by line and add comments inline.
```python
def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """
    model_config = getattr(model, "config", {"model_type": "custom"})  # the T5 model config; type `model_config` in the debug console to inspect it
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()  # serialize the config attributes into a Python dict
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
    # peft_config.task_type here is <TaskType.SEQ_2_SEQ_LM: 'SEQ_2_SEQ_LM'>
    # MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() is dict_keys(['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'])
    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
        peft_config, PromptLearningConfig
    ):
        return PeftModel(model, peft_config, adapter_name=adapter_name)
    if isinstance(peft_config, PromptLearningConfig):
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
```
Stepping in from the last line lands us in the PeftModelForSeq2SeqLM(PeftModel) class in peft -> peft_model.py. So my reading is that MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type] resolves our task type to PeftModelForSeq2SeqLM, which itself subclasses PeftModel, and the constructor receives model, peft_config, and adapter_name=adapter_name.
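For reference, the task-type mapping in mapping.py looks roughly like this; the sketch below is abridged and reproduced from memory, so treat it as an approximation of the actual source rather than a verbatim copy:

```python
# Sketch of the mapping in peft/mapping.py (abridged, names reproduced from memory):
MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
    "SEQ_CLS": PeftModelForSequenceClassification,
    "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
    "CAUSAL_LM": PeftModelForCausalLM,
    "TOKEN_CLS": PeftModelForTokenClassification,
    # ... plus QUESTION_ANS and FEATURE_EXTRACTION entries in recent versions
}
```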
PeftModelForSeq2SeqLM is documented as follows:
"""
Peft model for sequence-to-sequence language modeling.
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
Example:
```py
>>> from transformers import AutoModelForSeq2SeqLM
>>> from peft import PeftModelForSeq2SeqLM, get_peft_config
>>> config = {
... "peft_type": "LORA",
... "task_type": "SEQ_2_SEQ_LM",
... "inference_mode": False,
... "r": 8,
... "target_modules": ["q", "v"],
... "lora_alpha": 32,
... "lora_dropout": 0.1,
... "fan_in_fan_out": False,
... "enable_lora": None,
... "bias": "none",
... }
>>> peft_config = get_peft_config(config)
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
>>> peft_model.print_trainable_parameters()
trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
```
"""
The example above fine-tunes a t5-base model with LoRA in just a few lines, which is very convenient!
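By analogy, prefix tuning can also be described with a dict-style config passed to get_peft_config. The sketch below is my own, mirroring the PrefixTuningConfig from the model definition section; the field set is an assumption and may differ slightly across PEFT versions:

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModelForSeq2SeqLM, get_peft_config

# Illustrative prefix tuning config; mirrors PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM,
# inference_mode=False, num_virtual_tokens=20) used at the top of this post.
config = {
    "peft_type": "PREFIX_TUNING",
    "task_type": "SEQ_2_SEQ_LM",
    "inference_mode": False,
    "num_virtual_tokens": 20,
}
peft_config = get_peft_config(config)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
peft_model = PeftModelForSeq2SeqLM(model, peft_config)
peft_model.print_trainable_parameters()
```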
prefix tuning
After searching for a while without finding the prefix tuning code along this path, we open /root/miniconda3/envs/peft-practice/lib/python3.10/site-packages/peft/tuners/prefix_tuning.py directly and find that it is adapted from https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py (as its "Based on ..." comment notes).
```python
class PrefixEncoder(torch.nn.Module):
    r"""
    The torch.nn model to encode the prefix

    Input shape: (batch_size, prefix_length)
    Output shape: (batch_size, prefix_length, 2 * num_layers * hidden_size)
    Here prefix_length is num_virtual_tokens (20 in our config), hidden_size is e.g. 768,
    and prefix_hidden_size is the hidden width of the optional projection MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size),
            )
        else:
            # Without projection: directly learn one (2 * num_layers * hidden_size) vector per virtual token
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
```
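To see the shapes concretely, here is a small standalone exercise of the class above. The config object is a made-up stand-in (a SimpleNamespace carrying just the fields PrefixEncoder reads), not a real PEFT or P-tuning-v2 config:

```python
from types import SimpleNamespace
import torch

# Hypothetical config carrying only the fields PrefixEncoder reads.
cfg = SimpleNamespace(
    prefix_projection=True,
    pre_seq_len=20,            # num_virtual_tokens
    hidden_size=768,
    prefix_hidden_size=512,
    num_hidden_layers=12,
)
encoder = PrefixEncoder(cfg)

batch_size = 4
prefix = torch.arange(cfg.pre_seq_len).unsqueeze(0).expand(batch_size, -1)  # (batch_size, prefix_length)
past_key_values = encoder(prefix)
print(past_key_values.shape)  # torch.Size([4, 20, 18432]) == (batch, prefix_length, 2 * 12 * 768)
```

The output packs the per-layer key and value prefixes into one flat vector per virtual token, which the PEFT model later reshapes into past_key_values for every attention layer.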