首页 > 其他分享 >文盘rust--使用 Rust 构建RAG

文盘rust--使用 Rust 构建RAG

时间:2024-10-08 14:11:32浏览次数:7  
标签:RAG use -- await pub 文盘 let model config

作者:京东科技 贾世闻

RAG(Retrieval-Augmented Generation)技术在AI生态系统中扮演着至关重要的角色,特别是在提升大型语言模型(LLMs)的准确性和应用范围方面。RAG通过结合检索技术与LLM提示,从各种数据源检索相关信息,并将其与用户的问题结合,生成准确且丰富的回答。这一机制特别适用于需要应对信息不断更新的场景,因为大语言模型所依赖的参数知识本质上是静态的。

RAG技术的优势在于它能够利用外部知识库,引用大量的信息,以提供更深入、准确且有价值的答案,提高了生成文本的可靠性。此外,RAG模型具备检索库的更新机制,可以实现知识的即时更新,无需重新训练模型,这在及时性要求高的应用中占优势。

目前构建一个RAG并不是一个非常的事情。使用Langchain等成熟技术架构百十行代码就能构建一个Demo。那能不能利用目前的Rust生态构建一个简易的RAG。说干就干,本期和大家聊聊如果使用rust语言构建rag。

构建知识库

知识库构建主要是模型+向量库,为了保证所有系统中所有组件都使用rust构建,在限量数据库的选型上我们使用qdrant,纯rust构建的向量数据库。

知识库的构建最重要的步骤是embedding的过程。
过程如下:

  • 模型加载
  • 获取文本token
  • 通过模型获取文本的Embedding
    下面详细介绍每个过程细节及代码实现。

模型加载

以下代码用于加载模型和tokenizer


async fn build_model_and_tokenizer(model_config: &ConfigModel) -> Result<(BertModel, Tokenizer)> {
    let device = Device::new_cuda(0)?;
    let repo = Repo::with_revision(
        model_config.model_id.clone(),
        RepoType::Model,
        model_config.revision.clone(),
    );
    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = ApiBuilder::new()    
            .build()?;
        let api = api.repo(repo);
        let config = api.get("config.json").await?;
        let tokenizer = api.get("tokenizer.json").await?;
        let weights = if model_config.use_pth {
            api.get("pytorch_model.bin").await?
        } else {
            api.get("model.safetensors").await?
        };
        (config, tokenizer, weights)A
    };
    let config = std::fs::read_to_string(config_filename)?;
    let mut config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let vb = if model_config.use_pth {
        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
    };
    if model_config.approximate_gelu {
        config.hidden_act = HiddenAct::GeluApproximate;
    }
    let model = BertModel::load(vb, &config)?;
    Ok((model, tokenizer))
}

模型和tokenizer是系统中频繁调用的部分,所以为了避免重复加载,通过OnceCell构建静态全局变量


pub static GLOBAL_EMBEDDING_MODEL: OnceCell> = OnceCell::const_new();

pub async fn init_model_and_tokenizer() -> Arc<(BertModel, Tokenizer)> {
    let config = get_config().unwrap();
    let (m, t) = build_model_and_tokenizer(&config.model).await.unwrap();
    Arc::new((m, t))
}

在系统启动时加载模型


GLOBAL_RUNTIME.block_on(async {
    log::info!("global runtime start!");
    // 加载model
    GLOBAL_EMBEDDING_MODEL
        .get_or_init(init_model_and_tokenizer)
        .await;
});

Embedding 过程主要由一下函数实现。


pub async fn embedding_setence(content: &str) -> Result>> {
    let m_t = GLOBAL_EMBEDDING_MODEL.get().unwrap();
    let tokens = m_t
        .1
        .encode(content, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], &m_t.0.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    let sequence_output = m_t.0.forward(&token_ids, &token_type_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
    let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings)?;
    let encodings = embeddings.to_vec2::()?;
    Ok(encodings)
}

函数通过tokenizer encode输入的文本,再使用模型embed token 获取一个三维的Tensor,最后归一化张量。

数据入库

知识库构建是将待检索文本向量化后存储到向量数据库的过程。
本次使用京东云文档作为原始文本,加工为以下格式。数据加工过程这里就不累述了。


{
    "content": "# 服务计费\n\n主机迁移服务自身为免费服务,但是迁移目标为云主机镜像时,迁移过程依赖系统自动创建的 中转资源的配合,这些资源中涉及部分付费资源,会产生相应费用。\n\n迁移过程涉及的中转资付费资源配置及计费说明如下(单个迁移任务):\n\n|  | 云主机 | 云硬盘 | 弹性公网IP |\n| --- | --- | --- | ------ |\n| 计费类型 | 按配置 | 按配置 | 按用量 |\n| 规格配置 | 2C4G (c.n2.large或c.n3.large或c.n1.large) | 系统盘:40G 通用型SSD 数据盘:通用型SSD,数量及容量取决于源服务器系统盘及数据盘情况 | 30Mbps |\n| 费用预估 | 云主机规格每小时价格\\*迁移时长 | 云硬盘规格每小时价格\\*迁移时长 | 弹性公网IP每小时保有费\\*迁移时长 仅使用弹性公网IP入方向流量,只涉及IP保有用,不涉及流量费用 |\n\n> 提示:\n>\n> * 迁移时长取决于源服务器迁数据量以及源服务器公网出方向带宽,公网连接顺畅且源服务器公网出方向带宽不低于22.5Mbps的情况下(主机迁移为单线程传输,京东云云主机在单流传输下实际带宽为带宽上限的75%左右),实际数据容量为5GB的磁盘迁移时长在30分钟左右。\n> * 中转实例实例绑定的安全组出方向默认拒绝所有流量,因此默认情况下降不会产生任何公网出方向收费流量,但此配置也影响了云主机部分监控指标的上报,如需要监控中转实例的全部监控数据,可自行调整安全组规则方向出方向443端口。",
    "title": "服务计费说明",
    "product": "云主机 CVM",
    "url": "https://docs.jdcloud.com/cn/virtual-machines/server-migration-service/billing"
}

入库完整代码如下:


use anyhow::Error as E;
use anyhow::Result;
use candle_core::Device;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::tokio::Api, Repo, RepoType};
use qdrant_client::qdrant::CollectionExistsRequest;
use qdrant_client::qdrant::CreateCollectionBuilder;
use qdrant_client::qdrant::DeleteCollection;
use qdrant_client::qdrant::Distance;
use qdrant_client::qdrant::UpsertPointsBuilder;
use qdrant_client::qdrant::VectorParamsBuilder;
use qdrant_client::Payload;
use qdrant_client::{
    qdrant::{
        CollectionOperationResponse, CreateCollection, PointStruct, PointsOperationResponse,
        UpsertPoints,
    },
    Qdrant,
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::fs;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::OnceCell;
use uuid::Uuid;
use walkdir::WalkDir;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Doc {
    pub content: String,
    pub title: String,
    pub product: String,
    pub url: String,
}

#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub struct ModelConfig {
    #[serde(default = "ModelConfig::model_id_default")]
    pub model_id: String,
    #[serde(default = "ModelConfig::revision_default")]
    pub revision: String,
    #[serde(default = "ModelConfig::use_pth_default")]
    pub use_pth: bool,
    #[serde(default = "ModelConfig::approximate_gelu_default")]
    pub approximate_gelu: bool,
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            model_id: Self::model_id_default(),
            revision: Self::revision_default(),
            use_pth: Self::use_pth_default(),
            approximate_gelu: Self::approximate_gelu_default(),
        }
    }
}

impl ModelConfig {
    fn model_id_default() -> String {
        "moka-ai/m3e-large".to_string()
    }
    fn revision_default() -> String {
        "main".to_string()
    }
    fn use_pth_default() -> bool {
        true
    }
    fn approximate_gelu_default() -> bool {
        false
    }
}

pub static GLOBAL_MODEL: OnceCell> = OnceCell::const_new();
pub static GLOBAL_TOKEN: OnceCell> = OnceCell::const_new();

pub async fn init_model() -> Arc {
    let config = ModelConfig::default();
    let (m, _) = build_model_and_tokenizer(&config).await.unwrap();
    Arc::new(m)
}

pub async fn init_tokenizer() -> Arc {
    let config = ModelConfig::default();
    let (_, t) = build_model_and_tokenizer(&config).await.unwrap();
    Arc::new(t)
}

async fn build_model_and_tokenizer(model_config: &ModelConfig) -> Result<(BertModel, Tokenizer)> {
    let device = Device::new_cuda(0)?;
    let repo = Repo::with_revision(
        model_config.model_id.clone(),
        RepoType::Model,
        model_config.revision.clone(),
    );
    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = Api::new()?;
        let api = api.repo(repo);
        let config = api.get("config.json").await?;
        let tokenizer = api.get("tokenizer.json").await?;
        let weights = if model_config.use_pth {
            api.get("pytorch_model.bin").await?
        } else {
            api.get("model.safetensors").await?
        };
        (config, tokenizer, weights)
    };
    let config = std::fs::read_to_string(config_filename)?;
    let mut config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let vb = if model_config.use_pth {
        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
    };
    if model_config.approximate_gelu {
        config.hidden_act = HiddenAct::GeluApproximate;
    }
    let model = BertModel::load(vb, &config)?;
    Ok((model, tokenizer))
}

pub async fn embedding_setence(content: &str) -> Result>> {
    let m = GLOBAL_MODEL.get().unwrap();
    let t = GLOBAL_TOKEN.get().unwrap();
    let tokens = t.encode(content, true).map_err(E::msg)?.get_ids().to_vec();

    let token_ids = Tensor::new(&tokens[..], &m.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;

    let sequence_output = m.forward(&token_ids, &token_type_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
    let embeddings = (sequence_output.sum(1).unwrap() / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings).unwrap();
    let encodings = embeddings.to_vec2::()?;
    Ok(encodings)
}

pub fn normalize_l2(v: &Tensor) -> Result {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

pub struct QdrantClient {
    client: Qdrant,
}

impl QdrantClient {
    pub async fn create_collection(
        &self,
        request: impl Into,
    ) -> Result {
        let resp = self.client.create_collection(request).await?;
        Ok(resp)
    }

    pub async fn delete_collection(
        &self,
        request: impl Into,
    ) -> Result {
        let resp = self.client.delete_collection(request).await?;
        Ok(resp)
    }

    pub async fn collection_exists(
        &self,
        request: impl Into,
    ) -> Result {
        let resp = self.client.collection_exists(request).await?;
        Ok(resp)
    }

    pub async fn load_dir(&self, path: &str, collection_name: &str) {
        let mut points = vec![];
        for entry in WalkDir::new(path)
            .into_iter()
            .filter_map(Result::ok)
            .filter(|e| !e.file_type().is_dir() && e.file_name().to_str().is_some())
        {
            if let Some(p) = entry.path().to_str() {
                let id = Uuid::new_v4();
                let content = match fs::read_to_string(p) {
                    Ok(c) => c,
                    Err(_) => continue,
                };

                let doc = match from_str::(content.as_str()) {
                    Ok(d) => d,
                    Err(_) => continue,
                };
                let mut payload = Payload::new();
                payload.insert("content", doc.content);
                payload.insert("title", doc.title);
                payload.insert("product", doc.product);
                payload.insert("url", doc.url);
                let vector_contens = embedding_setence(content.as_str()).await.unwrap();
                let ps = PointStruct::new(id.to_string(), vector_contens[0].clone(), payload);
                points.push(ps);

                if points.len().eq(&100) {
                    let p = points.clone();
                    self.client
                        .upsert_points(UpsertPointsBuilder::new(collection_name, p).wait(true))
                        .await
                        .unwrap();
                    points.clear();
                    println!("batch finish");
                }
            }
        }

        if points.len().gt(&0) {
            self.client
                .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true))
                .await
                .unwrap();
        }
    }
}

#[tokio::main]
async fn main() {
    // 加载模型
    GLOBAL_MODEL.get_or_init(init_model).await;
    GLOBAL_TOKEN.get_or_init(init_tokenizer).await;

    let collection_name = "default_collection";

    // The Rust client uses Qdrant's GRPC interface
    let qdrant = Qdrant::from_url("http://localhost:6334").build().unwrap();
    let qdrant_client = QdrantClient { client: qdrant };

    if !qdrant_client
        .collection_exists(collection_name)
        .await
        .unwrap()
    {
        qdrant_client
            .create_collection(
                CreateCollectionBuilder::new(collection_name)
                    .vectors_config(VectorParamsBuilder::new(1024, Distance::Dot)),
            )
            .await
            .unwrap();
    }

    qdrant_client
        .load_dir("/root/jd_docs", collection_name)
        .await;

    println!("{:?}", qdrant_client.client.health_check().await);
}

以上代码要完成的任务如下:

推理服务

推理服务使用 rust 构建的 mistral.rs

由于国内访问hf 并不方便所以先通过 https://hf-mirror.com/ 现将模型下载到本地。本次使用qwen模型


HF_ENDPOINT="https://hf-mirror.com"  huggingface-cli download --repo-type model --resume-download Qwen/Qwen2-7B --local-dir /root/Qwen2-7B

启动 mistralrs-server


git clone https://github.com/EricLBuehler/mistral.rs
cd mistral.rs
cargo run  --bin mistralrs-server  --features cuda -- --port 3333 plain -m /root/Qwen2-7B  -a qwen2

推理服务调用

mistral.rs 支持 Openai 的 api接口,使用 openai-api-rs调用即可。推理时间比较长 timeout 要设置长一些,若timeout 时间太短有可能不等返回结果就已经强制超时。


pub static GLOBAL_OPENAI_CLIENT: Lazy> = Lazy::new(|| {
    let mut client =
        OpenAIClient::new_with_endpoint("http://10.0.0.7:3333/v1".to_string(), "EMPTY".to_string());
    client.timeout = Some(30);
    Arc::new(client)
});

pub async fn inference(content: &str, max_len: i64) -> Result> {
    let req = ChatCompletionRequest::new(
        "".to_string(),
        vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: chat_completion::Content::Text(content.to_string()),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    )
    .max_tokens(max_len);

    let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?;
    Ok(cr.choices[0].message.content.clone())
}

将Retriever和推理服务集成


pub async fn answer(question: &str, max_len: i64) -> Result> {
    let retriver = retriever(question, 1).await?;
    let mut context = "".to_string();

    for sp in retriver.result {
        let payload = sp.payload;
        let product = payload.get("product").unwrap().to_string();
        let title = payload.get("title").unwrap().to_string();
        let content = payload.get("content").unwrap().to_string();
        context.push_str(&product);
        context.push_str(&title);
        context.push_str(&content);
    }

    let prompt = format!(
        "你是一个云技术专家, 使用以下检索到的Context回答问题。用中文回答问题。
        Question: {}
        Context: {}
        ",
        question, context
    );

    log::info!("{}", prompt);

    let req = ChatCompletionRequest::new(
        "".to_string(),
        vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: chat_completion::Content::Text(prompt),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    )
    .max_tokens(max_len);

    let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?;
    Ok(cr.choices[0].message.content.clone())
}

后记

完整工程地址[embedding_server]https://github.com/jiashiwen/embedding_server

后续工程问题,多卡推理,多机推理,推理加速

资源对比

  • GPU 型号

    
    
  • |=========================================+========================+======================|
    |   0  NVIDIA A30                     Off |   00000000:00:07.0 Off |                    0 |
    | N/A   30C    P0             29W /  165W |       0MiB /  24576MiB |      0%      Default |
    |                                         |                        |             Disabled |
    +-----------------------------------------+------------------------+----------------------+
    
  • Embedding 资源

    • m3e-large

      • vllm

        
        
  • +-----------------------------------------------------------------------------------------+
    | Processes:                                                                              |
    |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
    |        ID   ID                                                               Usage      |
    |=========================================================================================|
    |    0   N/A  N/A    822789      C   ...iprojects/rag_demo/.venv/bin/python       1550MiB |
    +-----------------------------------------------------------------------------------------+
    
  • candle

    
    
      • +-----------------------------------------------------------------------------------------+
        | Processes:                                                                              |
        |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
        |        ID   ID                                                               Usage      |
        |=========================================================================================|
        |    0   N/A  N/A    823261      C   target/debug/embedding_server                1484MiB |
        +-----------------------------------------------------------------------------------------+
        
  • 推理资源

    • Qwen1.5-1.8B-Chat

      • vllm

        
        
  • |=========================================================================================|
    |    0   N/A  N/A    822437      C   /usr/bin/python3                            20440MiB |
    +-----------------------------------------------------------------------------------------+
    
  • mistral.rs

    
    
    • |=========================================================================================|
      |    0   N/A  N/A    822174      C   target/debug/mistralrs-server               22134MiB |
      +-----------------------------------------------------------------------------------------+
      
  • Qwen2-7B

    • vllm 现存溢出

      
      
  • [rank0]: OutOfMemoryError: CUDA out of memory. Tried to allocate 9.25 GiB. GPU
    
  • mistral.rs

    
    
      • |=========================================================================================|
        |    0   N/A  N/A    656923      C   target/debug/mistralrs-server               22006MiB |
        +-----------------------------------------------------------------------------------------+
        

从实际情况来看,Embedding 模型再资源占用情况 rust candle框架使用显存略小些;推理模型Qwen1.5-1.8B-Chat,vllm 资源占用略小。Qwen2-7B vllm直接显存溢出。

大部分框架中使用 hf-hub 采用同步调用,不支持境内的mirror。动手改造

src/api/tokio.rs



impl ApiBuilder {
    /// Set endpoint example 'https://hf-mirror.com'
    pub fn with_endpoint(mut self, endpoint: &str) -> Self {
        self.endpoint = endpoint.to_string();
        self
    }
}

标签:RAG,use,--,await,pub,文盘,let,model,config
From: https://www.cnblogs.com/Jcloud/p/18451528

相关文章

  • 分布式锁
    单体应用可以使用synchronized或Lock来加锁,synchronized推荐使用类锁,也就是字节码锁,这样保证是全局唯一的,如果使用对象锁,要根据业务确定这个对象锁在这个业务中是唯一的。对于微服务架构下,单体应用锁就不合适了,每个服务多个节点部署,虚拟机都不是用一个,肯定保证不了唯一性LU......
  • day02_基本的DOS命令
    电脑常用快捷键常用快捷键快捷键作用CTRL+c复制CTRL+v粘贴CTRL+x剪切CTRL+z撤销CTRL+s保存alt+f4关闭窗口del删除shift+del强制删除Windows+r打开“运行”窗口windows+e打开“我的文档”ctrl+alt+d......
  • 【PHP代码审计】文件上传
    一、认识上传漏洞文件上传漏洞是指用户上传了一个可执行的脚本文件,并通过此脚本文件获得了执行服务器端命令的能力,这种攻击方式是最直接和有效的文件上传本身是没问题的,有问题的是文件上传后,服务器怎么处理,解析文件。通过服务器的处理逻辑做的不够安全,则会导致上传漏洞。二、上......
  • 关系数据库的范式(Normal Form)知识点
    第2题的内容是:单选题已知关系R(A,B,C,D)和R上的函数依赖集F={B→D,AB→C},候选码是(1),关系R属于(2)。选项A.1NFB.2NFC.3NFD.BCNF分析这道题目考察的是关系数据库的范式(NormalForm)知识点。范式的相关内容:第一范式(1NF):要求关系中的每个域都是原子性的,即每个字段都是不可分割的......
  • 【PHP代码审计】命令执行
    RemoteCodeExecute远程代码执行原理:应用程序在调用一些能够将字符串转换为代码的函数(例如php中的eval中),没有考虑用户是否控制这个字符串,将造成代码执行漏洞。函数eval()//把字符串作为PHP代码执行assert()//检查一个断言是否为FALSE,可用来执行代码preg_replace()//......
  • C# WebService返回参数为DataTable报错“XML文档有错误”
    该问题由于DataTable列存在自定义类型。解决该报错需要以下几步:1、自定义类型增加xml序列化2、由于C#从XML反序列化DataSet或DataTable时的默认限制,所以需要先把调用方的项目开放限制,如果是.netframework项目,需要在app.config中添加<configuration><runtime>......
  • 五款倾斜摄影与三维数据处理工具介绍:GISBox、Cesiumlab、OSGBLab、灵易智模、倾斜伴侣
    随着三维数据处理技术的广泛应用,尤其是在城市规划、地理信息系统(GIS)、工程监测等领域,处理倾斜摄影、三维建模以及大规模数据管理的需求日益增加。以下是五款我精心挑选的倾斜摄影和三维数据处理工具——GISBox、Cesiumlab、OSGBLab、灵易智模和倾斜伴侣,本文将详细介绍它们的功能、......
  • Vue.js 自定义事件命名
    什么是Vue.js自定义事件命名?在Vue.js中,自定义事件是一种允许组件之间进行通信的重要机制。通过自定义事件,我们可以在父组件和子组件之间传递数据,实现组件的解耦和复用。Vue.js中的事件命名可以使用驼峰命名法或短横线命名法。但是,Vue.js官方强烈建议使用短横线命名法来定义自定义......
  • 基于VITA57.1标准的8路SFP+光纤通道数据传输FMC子卡模块
     板卡概述FMC213是我司自主研制的一块基于FMC标准的8路万兆光纤子卡模块。该板卡符合VITA57.1标准,该板卡可以作为一个理想的IO模块耦合至FPGA前端,8路SFP+的高速串行信号直接连接至FMC(HPC)接口的高速串行总线上,与FPGA内部的万兆位级收发器(MGT)互联,SFP+模块支持业界标准的小型可插......
  • CRICOS Data Structures and AlgorithmsHash Tables
    DataStructuresandAlgorithmsHashTablesPage1of3CRICOSProvideCode:00301J Note:hashArraystoresthekey,valueandstate(used,free,orpreviously-used)ofeveryhashEntry.WemuststoreboththekeyandvaluesinceweneedtocheckhashArrayto......