mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-23 14:14:46 +08:00
* feat: local search tool call result display * chore: add file copyright * fix: miss edit plan interrupt feedback * feat: disable pasting html into input box
76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import logging
|
|
from typing import List, Optional, Type
|
|
from langchain_core.tools import BaseTool
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForToolRun,
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from pydantic import BaseModel, Field
|
|
|
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
|
from src.rag import Document, Retriever, Resource, build_retriever
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RetrieverInput(BaseModel):
    """Input schema holding the single ``keywords`` string to search for."""

    keywords: str = Field(
        description="search keywords to look up",
    )
|
|
|
|
|
|
class RetrieverTool(BaseTool):
    """Tool that looks up keywords in the configured local resources.

    Each call forwards the keywords, together with ``resources``, to
    ``retriever.query_relevant_documents`` and returns the matches in
    serialised form.
    """

    name: str = "local_search_tool"
    description: str = "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
    args_schema: Type[BaseModel] = RetrieverInput

    # Backend that performs the actual lookup.
    retriever: Retriever = Field(default_factory=Retriever)
    # Resources handed to the retriever on every query.
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Query the retriever synchronously.

        Args:
            keywords: Search keywords passed to the retriever.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            A fixed "no results" message when the retriever yields nothing,
            otherwise a list with ``doc.to_dict()`` for every document
            (presumably plain dicts -- confirm against ``Document.to_dict``).
        """
        logger.info(
            "Retriever tool query: %s", keywords, extra={"resources": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            return "No results found from the local knowledge base."
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Async entry point; delegates to the synchronous ``_run``.

        Args:
            keywords: Search keywords passed to the retriever.
            run_manager: Optional async callback manager.

        Returns:
            Whatever ``_run`` returns for the same keywords.
        """
        # run_manager defaults to None, so only convert it when one was given;
        # calling get_sync() unconditionally raised AttributeError.
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        return self._run(keywords, sync_manager)
|
|
|
|
|
|
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Create a ``RetrieverTool`` bound to *resources*, if possible.

    Returns ``None`` when *resources* is empty/falsy or when
    ``build_retriever()`` yields a falsy value; otherwise returns a tool
    wrapping the built retriever and the given resources.
    """
    if resources:
        logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
        backend = build_retriever()
        if backend:
            return RetrieverTool(retriever=backend, resources=resources)
    return None
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the tool for one sample resource and run a query.
    resources = [
        Resource(
            uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
            title="西游记",
            description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
        )
    ]
    retriever_tool = get_retriever_tool(resources)
    if retriever_tool is None:
        # get_retriever_tool returns None when build_retriever() yields nothing;
        # stop with a clear message instead of failing on attribute access below.
        raise SystemExit("No retriever tool available: build_retriever() returned no retriever.")
    print(retriever_tool.name)
    print(retriever_tool.description)
    print(retriever_tool.args)
    print(retriever_tool.invoke("三打白骨精"))
|