Files
deer-flow/src/tools/retriever.py
JeffJiang 462752b462 feat: RAG Integration (#238)
* feat: add rag provider and retriever

* feat: retriever tool

* feat: add retriever tool to the researcher node

* feat: add rag http apis

* feat: new message input supports resource mentions

* feat: new message input component support resource mentions

* refactor: need_web_search to need_search

* chore: RAG integration docs

* chore: change example api host

* fix: user message color in dark mode

* fix: mentions style

* feat: add local_search_tool to researcher prompt

* chore: research prompt

* fix: ragflow page size and reporter width

* docs: ragflow integration and add acknowledgment projects

* chore: format
2025-05-28 14:13:46 +08:00

75 lines
2.4 KiB
Python

import asyncio
import logging
from typing import List, Optional, Type

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from src.config.tools import SELECTED_RAG_PROVIDER
from src.rag import Document, Retriever, Resource, build_retriever
# Module-level logger for the retriever tool.
logger = logging.getLogger(__name__)


# Argument schema for RetrieverTool: the tool's caller supplies a single
# free-text ``keywords`` string. Documented with comments rather than a class
# docstring because a pydantic model docstring becomes the "description" of
# the generated JSON schema.
class RetrieverInput(BaseModel):
    # Query text; RetrieverTool forwards it unchanged to the retriever.
    keywords: str = Field(description="search keywords to look up")
class RetrieverTool(BaseTool):
    """LangChain tool that searches a retriever over a fixed set of resources.

    The tool takes a single ``keywords`` argument (schema: ``RetrieverInput``)
    and restricts every query to the ``resources`` it was constructed with.
    """

    name: str = "local_search_tool"
    description: str = (
        "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # Backend that executes the query; get_retriever_tool passes the instance
    # returned by build_retriever().
    retriever: Retriever = Field(default_factory=Retriever)
    # Resources the search is restricted to (e.g. ``rag://dataset/...`` URIs).
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Query the retriever synchronously.

        Args:
            keywords: Search keywords supplied by the tool caller.
            run_manager: Optional LangChain callback manager (not used here).

        Returns:
            A list of ``Document.to_dict()`` results, or a plain message
            string when the retriever returns no documents.
        """
        logger.info(
            "Retriever tool query: %s", keywords, extra={"resources": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            # Return readable text instead of an empty list so the caller
            # gets an explicit "nothing found" signal.
            return "No results found from the local knowledge base."
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Async entry point; delegates to ``_run`` in a worker thread.

        Args:
            keywords: Search keywords supplied by the tool caller.
            run_manager: Optional async callback manager; converted to its
                sync counterpart for ``_run`` when present.

        Returns:
            Same as ``_run``.
        """
        # run_manager defaults to None, so only convert it when one was given.
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        # query_relevant_documents is a synchronous call; run it off the event
        # loop (LangChain's default BaseTool._arun does the same via an executor).
        return await asyncio.to_thread(self._run, keywords, sync_manager)
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Create a ``RetrieverTool`` bound to ``resources``, if possible.

    Returns ``None`` when ``resources`` is empty (nothing is logged or built in
    that case) or when ``build_retriever()`` returns a falsy value.
    """
    if resources:
        logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
        # Only hand out a tool when a usable retriever backend exists.
        if backend := build_retriever():
            return RetrieverTool(retriever=backend, resources=resources)
    return None
if __name__ == "__main__":
    # Manual smoke test. Needs a configured RAG provider and an existing dataset
    # behind the URI below. The sample resource is "Journey to the West"; the
    # query is one of the novel's episodes ("Three Fights with the White Bone Demon").
    resources = [
        Resource(
            uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
            title="西游记",
            description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
        )
    ]
    retriever_tool = get_retriever_tool(resources)
    if retriever_tool is None:
        # get_retriever_tool returns None when build_retriever() yields nothing;
        # fail with a clear message instead of an AttributeError below.
        raise SystemExit("Could not create retriever tool: no retriever available.")
    print(retriever_tool.name)
    print(retriever_tool.description)
    print(retriever_tool.args)
    print(retriever_tool.invoke("三打白骨精"))