第一篇介绍了如何配置最基本的环境并下载了GLM-4-9B-Chat到本地,接下来我们试着将GLM-4-9B-Chat接入LangChain。
LangChain 是一个基于大型语言模型(LLM)开发应用程序的框架。
LangChain 简化了LLM应用程序生命周期的每个阶段:
- 开发:使用 LangChain 的开源构建模块和组件构建应用程序。使用第三方集成(opens in a new tab)和模板(opens in a new tab)快速上手。
- 生产化:使用LangSmith检查、监控和评估你的链条,以便你可以持续优化和自信地部署。
- 部署:使用LangServe(opens in a new tab)将任何链条转变为 API。
LangChain提供了很多LLM的封装,内置了 OpenAI、LLAMA 等大模型的调用接口。具体方法可自行查阅,本教程中使用本地模型接入LangChain。
为了接入本地LLM,我们需要继承Langchain.llms.base.LLM 中的一个子类,重写其中的几个关键函数。
还是在上一篇所使用的 /root/autodl-tmp
目录,新建glm4LLM.py
文件:
from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
class ChatGLM4_LLM(LLM):
# 基于本地 ChatGLM4 自定义 LLM 类
tokenizer: AutoTokenizer = None
model: AutoModelForCausalLM = None
gen_kwargs: dict = None
def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
super().__init__()
print("正在从本地加载模型...")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto"
).eval()
print("完成本地模型的加载")
if gen_kwargs is None:
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
self.gen_kwargs = gen_kwargs
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any) -> str:
messages = [{"role": "user", "content": prompt}]
model_inputs = self.tokenizer.apply_chat_template(
messages, tokenize=True, return_tensors="pt", return_dict=True, add_generation_prompt=True
)
# 将input_ids移动到与模型相同的设备
device = next(self.model.parameters()).device
model_inputs = {key: value.to(device) for key, value in model_inputs.items()}
generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs['input_ids'], generated_ids)
]
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
@property
def _identifying_params(self) -> Dict[str, Any]:
"""返回用于识别LLM的字典,这对于缓存和跟踪目的至关重要。"""
return {
"model_name": "glm-4-9b-chat",
"max_length": self.gen_kwargs.get("max_length"),
"do_sample": self.gen_kwargs.get("do_sample"),
"top_k": self.gen_kwargs.get("top_k"),
}
@property
def _llm_type(self) -> str:
return "glm-4-9b-chat"
然后就可以进行简单的测试了,新建一个python文件testLLM.py
from glm4LLM import ChatGLM4_LLM
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
llm = ChatGLM4_LLM(model_name_or_path="/root/autodl-tmp/ZhipuAI/glm-4-9b-chat", gen_kwargs=gen_kwargs)
print(llm.invoke("你是谁"))
运行该文件,如果输出了回答代表已成功将llm接入LangChain
标签:GLM,self,9B,ids,langchain,LLM,kwargs,model,gen From: https://www.cnblogs.com/tarorat/p/18293967