mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 15:24:48 +08:00
feat: RAG Integration (#238)
* feat: add rag provider and retriever * feat: retriever tool * feat: add retriever tool to the researcher node * feat: add rag http apis * feat: new message input supports resource mentions * feat: new message input component support resource mentions * refactor: need_web_search to need_search * chore: RAG integration docs * chore: change example api host * fix: user message color in dark mode * fix: mentions style * feat: add local_search_tool to researcher prompt * chore: research prompt * fix: ragflow page size and reporter with * docs: ragflow integration and add acknowledgment projects * chore: format
This commit is contained in:
74
src/tools/retriever.py
Normal file
74
src/tools/retriever.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import logging
|
||||
from typing import List, Optional, Type
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.rag import Document, Retriever, Resource, build_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrieverInput(BaseModel):
|
||||
keywords: str = Field(description="search keywords to look up")
|
||||
|
||||
|
||||
class RetrieverTool(BaseTool):
|
||||
name: str = "local_search_tool"
|
||||
description: str = (
|
||||
"Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
|
||||
)
|
||||
args_schema: Type[BaseModel] = RetrieverInput
|
||||
|
||||
retriever: Retriever = Field(default_factory=Retriever)
|
||||
resources: list[Resource] = Field(default_factory=list)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
keywords: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> list[Document]:
|
||||
logger.info(
|
||||
f"Retriever tool query: {keywords}", extra={"resources": self.resources}
|
||||
)
|
||||
documents = self.retriever.query_relevant_documents(keywords, self.resources)
|
||||
if not documents:
|
||||
return "No results found from the local knowledge base."
|
||||
return [doc.to_dict() for doc in documents]
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
keywords: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> list[Document]:
|
||||
return self._run(keywords, run_manager.get_sync())
|
||||
|
||||
|
||||
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
|
||||
if not resources:
|
||||
return None
|
||||
logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
|
||||
retriever = build_retriever()
|
||||
|
||||
if not retriever:
|
||||
return None
|
||||
return RetrieverTool(retriever=retriever, resources=resources)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resources = [
|
||||
Resource(
|
||||
uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
|
||||
title="西游记",
|
||||
description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
|
||||
)
|
||||
]
|
||||
retriever_tool = get_retriever_tool(resources)
|
||||
print(retriever_tool.name)
|
||||
print(retriever_tool.description)
|
||||
print(retriever_tool.args)
|
||||
print(retriever_tool.invoke("三打白骨精"))
|
||||
Reference in New Issue
Block a user