
LangGraph Source Code Analysis | The BaseTool Template Class

Date: 2024-10-17 12:47:16
Tags: LangGraph, self, results, source code, BaseTool, str, domains, include, Optional


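BaseTool Source Code Analysis

BaseTool is the template (base) class that the LangChain framework uses to define tools.

Core attributes

  • name: a string that uniquely names the tool (used to identify it)
  • description: describes how, when, and why to use the tool; it helps the model decide when to call it
  • args_schema: the Pydantic model or schema used to validate the tool's input arguments
  • return_direct: if True, the tool's output is returned immediately as the result
  • response_format: defines the format of the tool's response
Example: TavilySearchResults(BaseTool)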
name
name: str = "tavily_search_results_json"
description
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
args_schema
class TavilyInput(BaseModel):
    """Input for the Tavily tool."""
    query: str = Field(description="search query to look up")


# The input must follow the schema rules defined in the TavilyInput class.
# The value of args_schema must be a class derived from BaseModel.
args_schema: Type[BaseModel] = TavilyInput
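  • Under the TavilyInput rules, an input that does not provide a query value raises a validation error
  • The Field function attaches metadata (here, a description) to the field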
response_format
response_format: Literal["content_and_artifact"] = "content_and_artifact"
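  • Literal restricts the value to specific literals; here the only accepted value is "content_and_artifact". The definition of Literal in Python's typing module is: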
@_LiteralSpecialForm
@_tp_cache(typed=True)
def Literal(self, *parameters):
    """Special typing form to define literal types (a.k.a. value types).

    This form can be used to indicate to type checkers that the corresponding
    variable or function parameter has a value equivalent to the provided
    literal (or one of several literals):

      def validate_simple(data: Any) -> Literal[True]:  # always returns True
          ...

      MODE = Literal['r', 'rb', 'w', 'wb']
      def open_helper(file: str, mode: MODE) -> str:
          ...

      open_helper('/some/path', 'r')  # Passes type check
      open_helper('/other/path', 'typo')  # Error in type checker

    Literal[...] cannot be subclassed. At runtime, an arbitrary value
    is allowed as type argument to Literal[...], but type checkers may
    impose restrictions.
    """
    # There is no '_type_check' call because arguments to Literal[...] are
    # values, not types.
    parameters = _flatten_literal_params(parameters)

    try:
        parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
    except TypeError:  # unhashable parameters
        pass

    return _LiteralGenericAlias(self, parameters)
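As the docstring above notes, Literal itself does not reject arbitrary values at runtime; it is type checkers (and validation libraries that read the annotation) that enforce it. The following is a minimal standard-library sketch of such a runtime check. The alias ResponseFormat and the helper check_response_format are hypothetical names introduced for illustration and are not part of LangChain.

```python
from typing import Literal, get_args

# Hypothetical alias for illustration: the two literal values a
# response format may take in this sketch.
ResponseFormat = Literal["content", "content_and_artifact"]


def check_response_format(value: str) -> str:
    """Return value unchanged if it is one of the allowed literals."""
    allowed = get_args(ResponseFormat)  # ('content', 'content_and_artifact')
    if value not in allowed:
        raise ValueError(f"response_format must be one of {allowed}, got {value!r}")
    return value
```

get_args extracts the literal values from the annotation, so the list of allowed values is written only once.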
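Query option attributes
  • max_results: the maximum number of results to return; defaults to 5.
  • search_depth: the depth of the search, either "basic" or "advanced"; defaults to "advanced".
  • include_domains: a list of domains to include in the results (empty by default, meaning all domains are included).
  • exclude_domains: a list of domains to exclude from the results.
  • include_answer: whether to include a short answer in the results; defaults to False.
  • include_raw_content: whether to return the parsed raw HTML content of each result (off by default).
  • include_images: whether to include related images in the results; defaults to False.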

Abstract method that subclasses must implement

    @abstractmethod
    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool.

        Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing.
        """
Example: TavilySearchResults(BaseTool)
api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
  • api_wrapper is a TavilySearchAPIWrapper instance that encapsulates the details of the API calls
class TavilySearchAPIWrapper(BaseModel):
    """Wrapper for Tavily Search API."""

    tavily_api_key: SecretStr

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key and endpoint exists in environment."""
        tavily_api_key = get_from_dict_or_env(
            values, "tavily_api_key", "TAVILY_API_KEY"
        )
        values["tavily_api_key"] = tavily_api_key

        return values

    def raw_results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        params = {
            "api_key": self.tavily_api_key.get_secret_value(),
            "query": query,
            "max_results": max_results,
            "search_depth": search_depth,
            "include_domains": include_domains,
            "exclude_domains": exclude_domains,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_images": include_images,
        }
        response = requests.post(
            # type: ignore
            f"{TAVILY_API_URL}/search",
            json=params,
        )
        response.raise_for_status()
        return response.json()

    def results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Run query through Tavily Search and return metadata.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search.
            exclude_domains: A list of domains to exclude from the search.
            include_answer: Whether to include the answer in the results.
            include_raw_content: Whether to include the raw content in the results.
            include_images: Whether to include images in the results.
        Returns:
            query: The query that was searched for.
            follow_up_questions: A list of follow up questions.
            response_time: The response time of the query.
            answer: The answer to the query.
            images: A list of images.
            results: A list of dictionaries containing the results:
                title: The title of the result.
                url: The url of the result.
                content: The content of the result.
                score: The score of the result.
                raw_content: The raw content of the result.
        """
        raw_search_results = self.raw_results(
            query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(raw_search_results["results"])

    async def raw_results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Get results from the Tavily Search API asynchronously."""

        # Function to perform the API call
        async def fetch() -> str:
            params = {
                "api_key": self.tavily_api_key.get_secret_value(),
                "query": query,
                "max_results": max_results,
                "search_depth": search_depth,
                "include_domains": include_domains,
                "exclude_domains": exclude_domains,
                "include_answer": include_answer,
                "include_raw_content": include_raw_content,
                "include_images": include_images,
            }
            async with aiohttp.ClientSession() as session:
                async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
                    if res.status == 200:
                        data = await res.text()
                        return data
                    else:
                        raise Exception(f"Error {res.status}: {res.reason}")

        results_json_str = await fetch()
        return json.loads(results_json_str)

    async def results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        results_json = await self.raw_results_async(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(results_json["results"])

    def clean_results(self, results: List[Dict]) -> List[Dict]:
        """Clean results from Tavily Search API."""
        clean_results = []
        for result in results:
            clean_results.append(
                {
                    "url": result["url"],
                    "content": result["content"],
                }
            )
        return clean_results
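  • raw_results(): calls the API synchronously.
  • raw_results_async(): calls the API asynchronously.
  • clean_results(): cleans and formats the query results, keeping only the url and content of each result.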
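To make the effect of clean_results() concrete, here is a small standalone sketch that applies the same filtering to a made-up response. The sample_response data is invented for illustration; only the field names (title, url, content, score) come from the docstring above.

```python
from typing import Dict, List


def clean_results(results: List[Dict]) -> List[Dict]:
    """Keep only the url and content of each result, as the wrapper does."""
    return [{"url": r["url"], "content": r["content"]} for r in results]


# Invented sample data shaped like the documented response.
sample_response = {
    "query": "what is langgraph",
    "results": [
        {
            "title": "Example page",
            "url": "https://example.com/a",
            "content": "First snippet.",
            "score": 0.91,
        },
        {
            "title": "Another page",
            "url": "https://example.com/b",
            "content": "Second snippet.",
            "score": 0.75,
        },
    ],
}

cleaned = clean_results(sample_response["results"])
```

After cleaning, each entry holds just two keys, so title and score are no longer available to the caller.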
    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
        """Use the tool."""
        # TODO: remove try/except, should be handled by BaseTool
        try:
            raw_results = self.api_wrapper.raw_results(
                query,
                self.max_results,
                self.search_depth,
                self.include_domains,
                self.exclude_domains,
                self.include_answer,
                self.include_raw_content,
                self.include_images,
            )
        except Exception as e:
            return repr(e), {}
        return self.api_wrapper.clean_results(raw_results["results"]), raw_results
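  • Passes the query and the configured options to TavilySearchAPIWrapper to fetch the results, returning the cleaned results together with the raw response.
  • If the query fails, returns the error message (repr of the exception) and an empty dict.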
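The return shape of _run, a (content, artifact) pair that degrades to (error string, empty dict) on failure, can be sketched without any network access. In this sketch run_search, fake_ok and fake_fail are hypothetical names; the fetch callable stands in for api_wrapper.raw_results.

```python
from typing import Callable, Dict, List, Tuple, Union

Content = Union[List[Dict[str, str]], str]


def run_search(fetch: Callable[[str], Dict], query: str) -> Tuple[Content, Dict]:
    """Return (cleaned results, raw response), or (error text, {}) on failure."""
    try:
        raw = fetch(query)
    except Exception as e:
        # Mirror the snippet above: report the error instead of raising.
        return repr(e), {}
    content = [{"url": r["url"], "content": r["content"]} for r in raw["results"]]
    return content, raw


def fake_ok(query: str) -> Dict:
    """Hypothetical fetcher that always succeeds."""
    return {"query": query, "results": [{"url": "https://example.com", "content": "hit"}]}


def fake_fail(query: str) -> Dict:
    """Hypothetical fetcher that always fails."""
    raise ConnectionError("network down")
```

The first element is what the model sees, and the second keeps the full raw response available to downstream code, which matches the "content_and_artifact" response format declared earlier.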

Core methods

_arun(): the asynchronous counterpart of _run()
    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool asynchronously.

        Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
        to child implementations to enable tracing.
        """
        if kwargs.get("run_manager") and signature(self._run).parameters.get(
            "run_manager"
        ):
            kwargs["run_manager"] = kwargs["run_manager"].get_sync()
        return await run_in_executor(None, self._run, *args, **kwargs)
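  • If a run_manager is passed and _run accepts a run_manager parameter, it is converted to its synchronous version with get_sync(); self._run is then run asynchronously on the default executor
  • run_in_executor is an asynchronous helper that runs synchronous code in an executor so that it does not block the current event loop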
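The same idea can be shown with asyncio's own loop.run_in_executor, which the sketch below uses in place of LangChain's run_in_executor helper. slow_sync_tool and arun_tool are hypothetical names, and time.sleep stands in for a blocking network call.

```python
import asyncio
import time
from functools import partial
from typing import List


def slow_sync_tool(query: str, delay: float = 0.05) -> str:
    """Hypothetical blocking tool body; sleep stands in for network I/O."""
    time.sleep(delay)
    return f"result for {query}"


async def arun_tool(query: str) -> str:
    """Run the blocking function on the default thread-pool executor."""
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(None, partial(slow_sync_tool, query))


async def main() -> List[str]:
    # Both calls run in worker threads, so the event loop stays responsive.
    return await asyncio.gather(arun_tool("a"), arun_tool("b"))
```

Running asyncio.run(main()) executes both calls concurrently, and gather returns the results in the order the awaitables were passed.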
invoke() and ainvoke()
def invoke(
    self,
    input: Union[str, dict, ToolCall],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Any:
    tool_input, kwargs = _prep_run_args(input, config, **kwargs)
    return self.run(tool_input, **kwargs)

async def ainvoke(
    self,
    input: Union[str, dict, ToolCall],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Any:
    tool_input, kwargs = _prep_run_args(input, config, **kwargs)
    return await self.arun(tool_input, **kwargs)
  • Serve as the entry points for executing the tool's logic
  • Prepare the input arguments and internally call run() or arun()
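The input accepted by invoke() may be a plain string, a dict of arguments, or a ToolCall. A rough sketch of the kind of normalization this implies is shown below. It is an assumption-laden simplification, not the actual _prep_run_args implementation: prep_run_args is a hypothetical name, config handling is omitted, and a ToolCall is assumed to be a dict carrying "args", "id" and "type": "tool_call".

```python
from typing import Any, Dict, Tuple, Union


def prep_run_args(
    tool_input: Union[str, dict], **kwargs: Any
) -> Tuple[Union[str, dict], Dict[str, Any]]:
    """Hypothetical, simplified normalization of an invoke() input."""
    # Assumption: a tool call is a dict tagged with type == "tool_call".
    if isinstance(tool_input, dict) and tool_input.get("type") == "tool_call":
        # Unwrap the arguments and remember which call they belong to.
        return dict(tool_input["args"]), {**kwargs, "tool_call_id": tool_input.get("id")}
    # Plain strings and argument dicts pass through unchanged.
    return tool_input, {**kwargs, "tool_call_id": None}
```

Under these assumptions, a tool-call dict is reduced to its arguments plus a tool_call_id keyword, while other inputs are forwarded as they are.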

From: https://blog.csdn.net/qq_45931691/article/details/143010255
