我正在尝试从某个检查点加载模型并将其用于推理。检查点文件夹如下所示。如何从该文件夹将模型加载到 torch 中。 资源 我可以找到用于从检查点文件加载,而不是文件夹。
import whisper_timestamped as whisper
from transformers import AutoProcessor, WhisperForConditionalGeneration
from peft import prepare_model_for_kbit_training, LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model
from peft import PeftModel, PeftConfig
import torch
from datasets import Dataset, Audio
from transformers import AutoFeatureExtractor, WhisperModel
peft_model_id = "aben118/finetuned_model/checkpoint-3900"
language = "en"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
print(model)
model = model.merge_and_unload()
model.save_pretrained(<model_path>)
但它以
.safetensors
格式保存它。我希望它成为一个可以使用
torch.load
加载的模型。
你几乎完成了! 你的代码的问题在于你正在尝试从 peft 加载一个完整的模型,而不是状态字典。
以下是加载模型的细分,以及如何使用
torch.load
保存它:
import torch
from transformers import WhisperForConditionalGeneration
from peft import PeftModel, PeftConfig
peft_model_id = "aben118/finetuned_model/checkpoint-3900"
# 加载 peft 配置
peft_config = PeftConfig.from_pretrained(peft_model_id)
# 加载基础 Whisper 模型
model = WhisperForConditionalGeneration.from_pretrained(
peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto"
)
# 加载 peft 权重
model = PeftModel.from_pretrained(model, peft_model_id)
# 这是重要的一步:合并 peft 权重以获得完整的模型
model = model.merge_and_unload()
# 现在你可以像往常一样保存模型
torch.save(model.state_dict(), "<model_path>.pt")
# 加载模型
loaded_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
loaded_model.load_state_dict(torch.load("<model_path>.pt"))
解释:
- 加载 peft 配置: 这定义了你如何将 peft 应用于你的基础模型。
- 加载基础模型: 这是你的 Whisper 模型,没有 peft 调整。
- 加载 peft 权重: 这会加载你训练过的 peft 参数。
-
合并权重:
model.merge_and_unload()
将 peft 权重合并到基础模型中。这会给你一个带有 peft 更新参数的完整模型。 -
使用
torch.save
保存: 现在你有了一个包含所有内容的独立模型,你可以使用torch.save
及其状态字典来保存它。 -
加载模型:
你可以创建一个基础 Whisper 模型的实例,并使用
torch.load
加载保存的模型状态字典。
现在你应该能够使用
torch.load
成功加载和使用你的模型进行推理了!