
Problems and Solutions When Integrating BGE-M3 with LangChain


This post is based on the tutorial at https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md. Since the large model is deployed locally, a few methods have to be overridden when subclassing LangChain's LLM class.
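For context, the members LangChain expects a custom LLM subclass to provide are the _call method (which performs the actual generation) and the _llm_type property; _identifying_params is optional but useful for caching and tracing. A minimal skeleton with placeholder logic only (not the tutorial's actual implementation) might look like this:

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class LocalLLMSkeleton(LLM):
    """Bare-bones custom LLM wrapper; the real model call goes in _call."""

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Run the locally deployed model on `prompt` and return its text output.
        raise NotImplementedError("plug in the local model call here")

    @property
    def _llm_type(self) -> str:
        # Short identifier LangChain uses for logging/serialization.
        return "local-llm-skeleton"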

During testing, however, the following error occurred:

/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
  warnings.warn(
Traceback (most recent call last):
  File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>
    print(llm.invoke("你是谁"))
          ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invoke
    self.generate_prompt(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_prompt
    return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generate
    output = self._generate_helper(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helper
    raise e
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helper
    self._generate(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generate
    self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
  File "/root/autodl-tmp/glm4LLM.py", line 40, in _call
    generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
              ^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forward
    inputs_embeds = self.embedding(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forward
    words_embeddings = self.word_embeddings(input_ids)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

The error is caused mainly by input_ids (the input data) and the model living on different devices: as the warning above shows, input_ids is on the CPU while the model is on CUDA.

With the modifications below the code runs successfully; the main change is the statement handling input_ids, which is now moved to the model's device.

from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class ChatGLM4_LLM(LLM):
    # Custom LLM class based on a locally deployed ChatGLM4 model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    gen_kwargs: dict = None
        
    def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        ).eval()
        print("Finished loading the local model")
        
        if gen_kwargs is None:
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = gen_kwargs
        
    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt", return_dict=True, add_generation_prompt=True
        )
        
        # Move the input tensors (input_ids etc.) to the same device as the model
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}
        
        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs['input_ids'], generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dict identifying this LLM; important for caching and tracing purposes."""
        return {
            "model_name": "glm-4-9b-chat",
            "max_length": self.gen_kwargs.get("max_length"),
            "do_sample": self.gen_kwargs.get("do_sample"),
            "top_k": self.gen_kwargs.get("top_k"),
        }

    @property
    def _llm_type(self) -> str:
        return "glm-4-9b-chat"
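With this class in place, the original test from glm4LLM.py (the llm.invoke("你是谁") call visible in the traceback) should run without the device mismatch. A usage sketch, assuming the weights live under a hypothetical local path:

# Hypothetical path; point it at wherever glm-4-9b-chat was downloaded locally.
model_path = "/root/autodl-tmp/glm-4-9b-chat"

llm = ChatGLM4_LLM(model_name_or_path=model_path)
print(llm.invoke("你是谁"))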

From: https://www.cnblogs.com/tarorat/p/18286378

    引言终于轮到我们最后的八段LED了!作为秒表的眼睛,必不可少的就是显示模块。八段LED初始化直接就叫做LED_Init()吧voidLED_Init(void){GPIO_InitTypeDefled; RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOE,ENABLE); led.GPIO_Mode=GPIO_Mode_IPU; led.GPIO_Pin=GPI......