mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-28 16:24:47 +08:00
feat: RAG Integration (#238)
* feat: add rag provider and retriever
* feat: retriever tool
* feat: add retriever tool to the researcher node
* feat: add rag http apis
* feat: new message input supports resource mentions
* feat: new message input component support resource mentions
* refactor: need_web_search to need_search
* chore: RAG integration docs
* chore: change example api host
* fix: user message color in dark mode
* fix: mentions style
* feat: add local_search_tool to researcher prompt
* chore: research prompt
* fix: ragflow page size and reporter with
* docs: ragflow integration and add acknowledgment projects
* chore: format
This commit is contained in:
@@ -5,6 +5,7 @@ import os
|
||||
|
||||
from .crawl import crawl_tool
|
||||
from .python_repl import python_repl_tool
|
||||
from .retriever import get_retriever_tool
|
||||
from .search import get_web_search_tool
|
||||
from .tts import VolcengineTTS
|
||||
|
||||
@@ -12,5 +13,6 @@ __all__ = [
|
||||
"crawl_tool",
|
||||
"python_repl_tool",
|
||||
"get_web_search_tool",
|
||||
"get_retriever_tool",
|
||||
"VolcengineTTS",
|
||||
]
|
||||
|
||||
74
src/tools/retriever.py
Normal file
74
src/tools/retriever.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import logging
|
||||
from typing import List, Optional, Type
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.rag import Document, Retriever, Resource, build_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Argument schema for RetrieverTool: the tool takes one free-text field,
# `keywords`, holding the search terms to look up.
# (Kept as `#` comments rather than a class docstring so the generated
# pydantic schema is not given a new description.)
class RetrieverInput(BaseModel):
    keywords: str = Field(
        description="search keywords to look up",
    )
|
||||
|
||||
|
||||
class RetrieverTool(BaseTool):
    """Tool that looks up keywords in the configured retriever.

    The tool is exposed under the name ``local_search_tool`` and accepts a
    single ``keywords`` argument (see ``RetrieverInput``). Every query is
    forwarded to ``retriever.query_relevant_documents`` together with the
    ``resources`` this tool instance was created with.
    """

    name: str = "local_search_tool"
    description: str = (
        "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
    )
    args_schema: Type[BaseModel] = RetrieverInput

    # Backend that actually executes the query.
    retriever: Retriever = Field(default_factory=Retriever)
    # Resources passed along with every query.
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Query the retriever synchronously.

        Args:
            keywords: Search keywords to look up.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            A list with one ``to_dict()`` result per matching document
            (presumably plain dicts -- confirm against ``Document.to_dict``),
            or a fixed explanatory string when nothing was found.
        """
        logger.info(
            f"Retriever tool query: {keywords}", extra={"resources": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            # A plain message (rather than an empty list) is returned so the
            # caller receives an explicit "nothing found" answer.
            return "No results found from the local knowledge base."
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Async entry point; delegates to the synchronous ``_run``.

        Args:
            keywords: Search keywords to look up.
            run_manager: Optional async callback manager. It defaults to
                ``None``, so it must not be dereferenced unconditionally.

        Returns:
            Whatever ``_run`` returns for the same keywords.
        """
        # Only convert the manager when one was actually supplied; calling
        # ``get_sync()`` on the default ``None`` would raise AttributeError.
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        # NOTE(review): ``_run`` is invoked directly here, so the retriever
        # call is not awaited or offloaded -- confirm this is acceptable.
        return self._run(keywords, sync_manager)
|
||||
|
||||
|
||||
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Create a ``RetrieverTool`` bound to the given resources.

    Args:
        resources: Resources the tool should search in.

    Returns:
        A ``RetrieverTool`` wired to the retriever returned by
        ``build_retriever()``, or ``None`` when ``resources`` is empty or
        ``build_retriever()`` yields nothing usable.
    """
    if resources:
        logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
        backend = build_retriever()
        if backend:
            return RetrieverTool(retriever=backend, resources=resources)
    # Either no resources were requested or no retriever could be built.
    return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the tool for one sample resource, print its
    # metadata, then run a single query against it.
    resources = [
        Resource(
            uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
            title="西游记",
            description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
        )
    ]
    retriever_tool = get_retriever_tool(resources)
    if retriever_tool is None:
        # get_retriever_tool() returns None when build_retriever() gives no
        # retriever; stop with a clear message instead of an AttributeError.
        raise SystemExit(
            "Retriever tool unavailable: build_retriever() returned no retriever."
        )
    print(retriever_tool.name)
    print(retriever_tool.description)
    print(retriever_tool.args)
    print(retriever_tool.invoke("三打白骨精"))
|
||||
@@ -61,5 +61,9 @@ def get_web_search_tool(max_search_results: int):
|
||||
if __name__ == "__main__":
|
||||
results = LoggedDuckDuckGoSearch(
|
||||
name="web_search", max_results=3, output_format="list"
|
||||
).invoke("cute panda")
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
)
|
||||
print(results.name)
|
||||
print(results.description)
|
||||
print(results.args)
|
||||
# .invoke("cute panda")
|
||||
# print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
Reference in New Issue
Block a user