本文基于 https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md 提供的教程。由于使用的是本地部署的大模型，在继承 LangChain 中的 LLM 类时需要重写几个函数。
但是在实际测试时出现了以下错误：
/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
warnings.warn(
Traceback (most recent call last):
File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>
print(llm.invoke("你是谁"))
^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invoke
self.generate_prompt(
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_prompt
return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generate
output = self._generate_helper(
^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helper
raise e
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helper
self._generate(
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generate
self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
File "/root/autodl-tmp/glm4LLM.py", line 40, in _call
generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sample
outputs = self(
^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forward
inputs_embeds = self.embedding(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forward
words_embeddings = self.word_embeddings(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forward
return F.embedding(
^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
错误的主要原因是 input_ids（输入数据）与 model（模型）所在的设备不一致：input_ids 在 CPU 上，而模型在 CUDA 上。
修改为下面的代码后可以成功运行，主要改动是在调用 generate 之前，把 model_inputs 中的张量（包括 input_ids）移动到模型所在的设备上。
from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
class ChatGLM4_LLM(LLM):
    """Custom LangChain ``LLM`` backed by a locally loaded ChatGLM4 checkpoint.

    The tokenizer and model are loaded once in ``__init__``; every call to
    ``_call`` wraps the prompt as a single user message, runs
    ``model.generate`` and returns only the newly generated text.
    """

    # Tokenizer loaded from ``model_name_or_path`` in ``__init__``.
    tokenizer: AutoTokenizer = None
    # Causal LM loaded from ``model_name_or_path`` in ``__init__``.
    model: AutoModelForCausalLM = None
    # Keyword arguments forwarded verbatim to ``model.generate``.
    gen_kwargs: dict = None

    def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
        """Load the tokenizer and model from ``model_name_or_path``.

        Args:
            model_name_or_path: Path or identifier passed to both
                ``from_pretrained`` calls.
            gen_kwargs: Generation arguments forwarded to ``model.generate``.
                When ``None``, defaults to ``max_length=2500``,
                ``do_sample=True`` and ``top_k=1``.
        """
        super().__init__()
        print("正在从本地加载模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        ).eval()
        print("完成本地模型的加载")
        if gen_kwargs is None:
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = gen_kwargs

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a reply for ``prompt`` and return it as plain text.

        Args:
            prompt: User message; sent as a single ``user`` chat turn.
            stop: Optional stop substrings. The reply is cut off just before
                the earliest occurrence of any of them. Empty strings are
                ignored.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The decoded text of the newly generated tokens (prompt tokens
            removed, special tokens skipped), truncated at ``stop`` if given.
        """
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        )

        # Move every input tensor (input_ids included) onto the device that
        # holds the model's parameters; otherwise generate() fails with
        # "Expected all tensors to be on the same device".
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}

        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        # generate() returns prompt + continuation; keep only the continuation.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Honour the caller's stop sequences instead of silently ignoring them.
        if stop:
            cut_points = [
                index
                for index in (response.find(token) for token in stop if token)
                if index != -1
            ]
            if cut_points:
                response = response[: min(cut_points)]
        return response

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dict identifying this LLM, used for caching and tracing."""
        return {
            "model_name": "glm-4-9b-chat",
            "max_length": self.gen_kwargs.get("max_length"),
            "do_sample": self.gen_kwargs.get("do_sample"),
            "top_k": self.gen_kwargs.get("top_k"),
        }

    @property
    def _llm_type(self) -> str:
        """Return the type tag of this LLM implementation."""
        return "glm-4-9b-chat"
标签:BGE,File,self,py,LangChain,M3,kwargs,line,root
From: https://www.cnblogs.com/tarorat/p/18286378