首页 > 其他分享 >graphrag api调用

graphrag api调用

时间:2024-07-30 16:53:10浏览次数:15  
标签:调用 graphrag root api store config data final dir

"""
参考:https://microsoft.github.io/graphrag/posts/get_started/
1. 初始化家目录:python -m graphrag.index --init --root ./ragtest
2. 初始化索引:python -m graphrag.index --root ./ragtest

脚本需要放置在ragtest目录下运行
"""

import os
import re
from pathlib import Path
from typing import cast, Union, Tuple

import pandas as pd

from graphrag.config import (
    GraphRagConfig,
    create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.query.input.loaders.dfs import (
    store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.query.factories import get_local_search_engine
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)

reporter = PrintProgressReporter("")


class LocalSearchEngine:

    """
    根据官方代码适当调整:代码启动加载search_agent避免重复加载,对外仅暴露一个调用接口
    response_type 返回: Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report
    """
    def __init__(self, data_dir: Union[str, None], root_dir: Union[str, None]):
        self.data_dir, self.root_dir, self.config = self._configure_paths_and_settings(
            data_dir, root_dir
        )
        self.description_embedding_store = self._get_embedding_description_store()
        self.agent = self.search_agent(
            community_level=2, response_type="Single Paragraph"
        )

    def _configure_paths_and_settings(
        self, data_dir: Union[str, None], root_dir: Union[str, None]
    ) -> Tuple[str, Union[str, None], GraphRagConfig]:
        if data_dir is None and root_dir is None:
            msg = "Either data_dir or root_dir must be provided."
            raise ValueError(msg)
        if data_dir is None:
            data_dir = self._infer_data_dir(cast(str, root_dir))
        config = self._create_graphrag_config(root_dir, data_dir)
        return data_dir, root_dir, config

    @staticmethod
    def _infer_data_dir(root: str) -> str:
        output = Path(root) / "output"
        if output.exists():
            folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True)
            if folders:
                folder = folders[0]
                return str((folder / "artifacts").absolute())
        msg = f"Could not infer data directory from root={root}"
        raise ValueError(msg)

    def _create_graphrag_config(
        self, root: Union[str, None], data_dir: Union[str, None]
    ) -> GraphRagConfig:
        return self._read_config_parameters(cast(str, root or data_dir))

    @staticmethod
    def _read_config_parameters(root: str) -> GraphRagConfig:
        _root = Path(root)
        settings_yaml = _root / "settings.yaml"
        if not settings_yaml.exists():
            settings_yaml = _root / "settings.yml"
        settings_json = _root / "settings.json"

        if settings_yaml.exists():
            reporter.info(f"Reading settings from {settings_yaml}")
            with settings_yaml.open("rb") as file:
                import yaml

                data = yaml.safe_load(
                    file.read().decode(encoding="utf-8", errors="strict")
                )
                return create_graphrag_config(data, root)

        if settings_json.exists():
            reporter.info(f"Reading settings from {settings_json}")
            with settings_json.open("rb") as file:
                import json

                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
                return create_graphrag_config(data, root)

        reporter.info("Reading settings from environment variables")
        return create_graphrag_config(root_dir=root)

    @staticmethod
    def _get_embedding_description_store(
        vector_store_type: str = VectorStoreType.LanceDB, config_args: dict = None
    ):
        if not config_args:
            config_args = {}

        config_args.update(
            {
                "collection_name": config_args.get(
                    "query_collection_name",
                    config_args.get("collection_name", "description_embedding"),
                ),
            }
        )

        description_embedding_store = VectorStoreFactory.get_vector_store(
            vector_store_type=vector_store_type, kwargs=config_args
        )

        description_embedding_store.connect(**config_args)
        return description_embedding_store

    def search_agent(self, community_level: int, response_type: str):
        """获取搜索引擎"""
        data_path = Path(self.data_dir)

        final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
        final_community_reports = pd.read_parquet(
            data_path / "create_final_community_reports.parquet"
        )
        final_text_units = pd.read_parquet(
            data_path / "create_final_text_units.parquet"
        )
        final_relationships = pd.read_parquet(
            data_path / "create_final_relationships.parquet"
        )
        final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
        final_covariates_path = data_path / "create_final_covariates.parquet"
        final_covariates = (
            pd.read_parquet(final_covariates_path)
            if final_covariates_path.exists()
            else None
        )

        vector_store_args = (
            self.config.embeddings.vector_store
            if self.config.embeddings.vector_store
            else {}
        )
        vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

        description_embedding_store = self._get_embedding_description_store(
            vector_store_type=vector_store_type,
            config_args=vector_store_args,
        )
        entities = read_indexer_entities(final_nodes, final_entities, community_level)
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        covariates = (
            read_indexer_covariates(final_covariates)
            if final_covariates is not None
            else []
        )

        return get_local_search_engine(
            self.config,
            reports=read_indexer_reports(
                final_community_reports, final_nodes, community_level
            ),
            text_units=read_indexer_text_units(final_text_units),
            entities=entities,
            relationships=read_indexer_relationships(final_relationships),
            covariates={"claims": covariates},
            description_embedding_store=description_embedding_store,
            response_type=response_type,
        )

    def run_search(self, query: str):
        """
        搜索入口
        :param query: 问题
        :return:
        """
        result = self.agent.search(query=query)
        return self.remove_sources(result.response)

    @staticmethod
    def remove_sources(text):
        """
        使用正则表达式匹配 [Data: Sources (82, 14, 42, 98)] 这种格式的字符串
        :param text:
        :return:
        """
        cleaned_text = re.sub(r'\[Data: [^]]+\]', '', text)
        return cleaned_text


# Example usage
BASEDIR = os.path.dirname(__file__)  # Set your base directory path here

local_search_engine = LocalSearchEngine(data_dir=None, root_dir=BASEDIR)
if __name__ == '__main__':
    local_res = local_search_engine.run_search(
        query="如何添加设备",
    )
    print(local_res)

搜索方式有global跟loca两种。如果想通过api调用global,修改几个关键字就行。

标签:调用,graphrag,root,api,store,config,data,final,dir
From: https://www.cnblogs.com/52-qq/p/18332835

相关文章

  • 部署 Blender 脚本以用作 Web 服务器上的 api
    我在Nextjs中有一个网站和一个混合器脚本,它获取图像、纹理图像并将它们合并在一起,同时应用一些视觉效果(如深度)、渲染结果并将渲染结果的png图像返回到前端以供使用网站中的img标签。我制作了一个pythonFlask应用程序,安装了搅拌机,并制定了将搅拌机作为子进程运行的路线,......
  • 我用文心快码 Baidu Comate关联了自己的API文档,一键生成代码
    为了让大家快速掌握文心快码BaiduComate智能代码助手的高效使用技巧,我们为你准备了以下简易实操步骤,让你轻松地基于业务API文档生成符合业务规范的新代码。以某银行订单系统的支付业务为例:Step1:上传银行支付系统的API文档Step2:参考关联的API文档,BaiduComate智能代......
  • 如何检查多个依赖项中是否至少有一个在 Fastapi 中传递
    我有一个端点应该适用于两个不同的用户组,如果用户不属于任一组,我想给出正确的错误消息。对于这些组,我创建了也在其他端点中使用的依赖项:defis_teacher(email:str=Depends(get_email),db=Depends(get_db))->bool:teacher=...ifnotteacher:......
  • 阿里云设置跨域规则后调用OSS时仍然报No'Access-Control-Allow-Origin'的错误原因和解
    问题描述为了实现跨域访问,保证跨域数据传输的安全进行,在OSS控制台设置了跨域CORS规则后,通过SDK进行程序调用时报以下错误。No'Access-Control-Allow-Origin'headerispresentontherequestedresource问题原因出现跨域问题的原因如下:跨域CORS规则设置异常:未正确设......
  • 为什么在 CDS BETA 后出现 CDS API 格式错误?
    [对于上下文,我使用的是macOS和Python]安装cdsapi后,基本上按照官方网站的用户指南中的说明进行cdsapi设置:https://cds-beta.climate.copernicus.eu/how-to-api,并运行此示例代码进行数据访问,返回此错误{示例代码}importcdsapi客户端=cdsapi.Cli......
  • 雅虎财经 API 未检索数据
    `importstreamlitasstfromdatetimeimportdateimportyfinanceasyffromprophetimportProphetfromprophet.plotimportplot_plotlyfromplotlyimportgraph_objsasgoSTART="2014-01-01"TODAY=date.today().strftime("%Y-%m-%d"......
  • python身份证号码+姓名一致性核验、身份证号码真伪查询API集成
    身份证号码+姓名核验的方式,顾名思义是身份证二要素核验,一般情况下,身份证真伪查询需要上公安户籍系统查询,但此种方式仅适合个人查询,企业要想随时随地实现身份证实名认证的功能,便需要集成身份证实名认证接口功能。翔云人工智能开放平台提供身份证号实名认证接口,实时联网,上传身份证......
  • 如何将数字分配给返回的 python 数据列表,我可以调用这些数据来打印
    这里完全是菜鸟。我在网上搜索过,找不到我想要做的事情的答案。我的代码在这里:importbs4asbsimporturllib.requestsauce=urllib.request.urlopen('https://www.amazon.com/gp/rss/bestsellers/kitchen/289851/ref=zg_bs_289851_rsslink').read()soup=bs.Beautiful......
  • python API增值税发票四要素核验、数电票查验、医疗票查验
    长期以来,对发票进行高效的管理一直困扰着众多企业财务,手动录入效率慢、出错率高、纸质发票易丢失等。今天,翔云为广大企业提供了发票查验接口与财政票据查验接口服务,可针对增值税发票管理系统开具发票,医疗票据、非税收入等财政类票据进行真伪查验。翔云发票识别接口,使得企业财务无......
  • 纳米体育数据API电竞数据API:资料库数据包接口文档API示例⑥
    纳米体育数据的数据接口通过JSON拉流方式获取200多个国家的体育赛事实时数据或历史数据的编程接口,无请求次数限制,可按需购买,接口稳定高效;覆盖项目包括足球、篮球、网球、电子竞技、奥运等专题、数据内容。纳米数据API2.0版本包含http协议以及websocket协议,主要通过http获取数......