feat: RAG Integration (#238)

* feat: add rag provider and retriever

* feat: retriever tool

* feat: add retriever tool to the researcher node

* feat: add rag http apis

* feat: new message input supports resource mentions

* feat: new message input component support resource mentions

* refactor: need_web_search to need_search

* chore: RAG integration docs

* chore: change example api host

* fix: user message color in dark mode

* fix: mentions style

* feat: add local_search_tool to researcher prompt

* chore: research prompt

* fix: ragflow page size and reporter with

* docs: ragflow integration and add acknowledgment projects

* chore: format
This commit is contained in:
JeffJiang
2025-05-28 14:13:46 +08:00
committed by GitHub
parent 0565ab6d27
commit 462752b462
43 changed files with 1172 additions and 181 deletions

View File

@@ -2,16 +2,21 @@
# SPDX-License-Identifier: MIT
import os
from dataclasses import dataclass, fields
from dataclasses import dataclass, field, fields
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from src.rag.retriever import Resource
@dataclass(kw_only=True)
class Configuration:
"""The configurable fields."""
resources: list[Resource] = field(
default_factory=list
) # Resources to be used for the research
max_plan_iterations: int = 1 # Maximum number of plan iterations
max_step_num: int = 3 # Maximum number of steps in a plan
max_search_results: int = 3 # Maximum number of search results

View File

@@ -17,3 +17,10 @@ class SearchEngine(enum.Enum):
# Tool configuration
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
class RAGProvider(enum.Enum):
    # Supported RAG backends; add new members here when new providers
    # are implemented under src/rag/.
    RAGFLOW = "ragflow"


# Unset (None) means RAG is disabled; compared against RAGProvider values
# when building the retriever.
SELECTED_RAG_PROVIDER = os.getenv("RAG_PROVIDER")

View File

@@ -17,6 +17,7 @@ from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
get_web_search_tool,
get_retriever_tool,
python_repl_tool,
)
@@ -206,10 +207,11 @@ def human_feedback_node(
def coordinator_node(
state: State,
state: State, config: RunnableConfig
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
"""Coordinator node that communicate with customers."""
logger.info("Coordinator talking.")
configurable = Configuration.from_runnable_config(config)
messages = apply_prompt_template("coordinator", state)
response = (
get_llm_by_type(AGENT_LLM_MAP["coordinator"])
@@ -242,7 +244,7 @@ def coordinator_node(
logger.debug(f"Coordinator response: {response}")
return Command(
update={"locale": locale},
update={"locale": locale, "resources": configurable.resources},
goto=goto,
)
@@ -326,14 +328,14 @@ async def _execute_agent_step(
logger.warning("No unexecuted step found")
return Command(goto="research_team")
logger.info(f"Executing step: {current_step.title}")
logger.info(f"Executing step: {current_step.title}, agent: {agent_name}")
# Format completed steps information
completed_steps_info = ""
if completed_steps:
completed_steps_info = "# Existing Research Findings\n\n"
for i, step in enumerate(completed_steps):
completed_steps_info += f"## Existing Finding {i+1}: {step.title}\n\n"
completed_steps_info += f"## Existing Finding {i + 1}: {step.title}\n\n"
completed_steps_info += f"<finding>\n{step.execution_res}\n</finding>\n\n"
# Prepare the input for the agent with completed steps info
@@ -347,6 +349,19 @@ async def _execute_agent_step(
# Add citation reminder for researcher agent
if agent_name == "researcher":
if state.get("resources"):
resources_info = "**The user mentioned the following resource files:**\n\n"
for resource in state.get("resources"):
resources_info += f"- {resource.title} ({resource.description})\n"
agent_input["messages"].append(
HumanMessage(
content=resources_info
+ "\n\n"
+ "You MUST use the **local_search_tool** to retrieve the information from the resource files.",
)
)
agent_input["messages"].append(
HumanMessage(
content="IMPORTANT: DO NOT include inline citations in the text. Instead, track all sources and include a References section at the end using link reference format. Include an empty line between each citation for better readability. Use this format for each reference:\n- [Source Title](URL)\n\n- [Another Source](URL)",
@@ -377,6 +392,7 @@ async def _execute_agent_step(
)
recursion_limit = default_recursion_limit
logger.info(f"Agent input: {agent_input}")
result = await agent.ainvoke(
input=agent_input, config={"recursion_limit": recursion_limit}
)
@@ -468,11 +484,16 @@ async def researcher_node(
"""Researcher node that do research"""
logger.info("Researcher node is researching.")
configurable = Configuration.from_runnable_config(config)
tools = [get_web_search_tool(configurable.max_search_results), crawl_tool]
retriever_tool = get_retriever_tool(state.get("resources", []))
if retriever_tool:
tools.insert(0, retriever_tool)
logger.info(f"Researcher tools: {tools}")
return await _setup_and_execute_agent_step(
state,
config,
"researcher",
[get_web_search_tool(configurable.max_search_results), crawl_tool],
tools,
)

View File

@@ -1,12 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import operator
from typing import Annotated
from langgraph.graph import MessagesState
from src.prompts.planner_model import Plan
from src.rag import Resource
class State(MessagesState):
@@ -15,6 +13,7 @@ class State(MessagesState):
# Runtime Variables
locale: str = "en-US"
observations: list[str] = []
resources: list[Resource] = []
plan_iterations: int = 0
current_plan: Plan | str = None
final_report: str = ""

View File

@@ -57,14 +57,15 @@ Before creating a detailed plan, assess if there is sufficient context to answer
Different types of steps have different web search requirements:
1. **Research Steps** (`need_web_search: true`):
1. **Research Steps** (`need_search: true`):
- Retrieve information from the file with the URL with `rag://` or `http://` prefix specified by the user
- Gathering market data or industry trends
- Finding historical information
- Collecting competitor analysis
- Researching current events or news
- Finding statistical data or reports
2. **Data Processing Steps** (`need_web_search: false`):
2. **Data Processing Steps** (`need_search: false`):
- API calls and data extraction
- Database queries
- Raw data collection from existing sources
@@ -74,10 +75,10 @@ Different types of steps have different web search requirements:
## Exclusions
- **No Direct Calculations in Research Steps**:
- Research steps should only gather data and information
- All mathematical calculations must be handled by processing steps
- Numerical analysis must be delegated to processing steps
- Research steps focus on information gathering only
- Research steps should only gather data and information
- All mathematical calculations must be handled by processing steps
- Numerical analysis must be delegated to processing steps
- Research steps focus on information gathering only
## Analysis Framework
@@ -135,16 +136,16 @@ When planning information gathering, consider these key aspects and ensure COMPR
- To begin with, repeat user's requirement in your own words as `thought`.
- Rigorously assess if there is sufficient context to answer the question using the strict criteria above.
- If context is sufficient:
- Set `has_enough_context` to true
- No need to create information gathering steps
- Set `has_enough_context` to true
- No need to create information gathering steps
- If context is insufficient (default assumption):
- Break down the required information using the Analysis Framework
- Create NO MORE THAN {{ max_step_num }} focused and comprehensive steps that cover the most essential aspects
- Ensure each step is substantial and covers related information categories
- Prioritize breadth and depth within the {{ max_step_num }}-step constraint
- For each step, carefully assess if web search is needed:
- Research and external data gathering: Set `need_web_search: true`
- Internal data processing: Set `need_web_search: false`
- Break down the required information using the Analysis Framework
- Create NO MORE THAN {{ max_step_num }} focused and comprehensive steps that cover the most essential aspects
- Ensure each step is substantial and covers related information categories
- Prioritize breadth and depth within the {{ max_step_num }}-step constraint
- For each step, carefully assess if web search is needed:
- Research and external data gathering: Set `need_search: true`
- Internal data processing: Set `need_search: false`
- Specify the exact data to be collected in step's `description`. Include a `note` if necessary.
- Prioritize depth and volume of relevant information - limited information is not acceptable.
- Use the same language as the user to generate the plan.
@@ -156,10 +157,10 @@ Directly output the raw JSON format of `Plan` without "```json". The `Plan` inte
```ts
interface Step {
need_web_search: boolean; // Must be explicitly set for each step
need_search: boolean; // Must be explicitly set for each step
title: string;
description: string; // Specify exactly what data to collect
step_type: "research" | "processing"; // Indicates the nature of the step
description: string; // Specify exactly what data to collect. If the user input contains a link, please retain the full Markdown format when necessary.
step_type: "research" | "processing"; // Indicates the nature of the step
}
interface Plan {
@@ -167,7 +168,7 @@ interface Plan {
has_enough_context: boolean;
thought: string;
title: string;
steps: Step[]; // Research & Processing steps to get more context
steps: Step[]; // Research & Processing steps to get more context
}
```
@@ -179,8 +180,8 @@ interface Plan {
- Prioritize BOTH breadth (covering essential aspects) AND depth (detailed information on each aspect)
- Never settle for minimal information - the goal is a comprehensive, detailed final report
- Limited or insufficient information will lead to an inadequate final report
- Carefully assess each step's web search requirement based on its nature:
- Research steps (`need_web_search: true`) for gathering information
- Processing steps (`need_web_search: false`) for calculations and data processing
- Carefully assess each step's web search or retrieve from URL requirement based on its nature:
- Research steps (`need_search: true`) for gathering information
- Processing steps (`need_search: false`) for calculations and data processing
- Default to gathering more information unless the strictest sufficient context criteria are met
- Always use the language specified by the locale = **{{ locale }}**.
- Always use the language specified by the locale = **{{ locale }}**.

View File

@@ -13,9 +13,7 @@ class StepType(str, Enum):
class Step(BaseModel):
need_web_search: bool = Field(
..., description="Must be explicitly set for each step"
)
need_search: bool = Field(..., description="Must be explicitly set for each step")
title: str
description: str = Field(..., description="Specify exactly what data to collect")
step_type: StepType = Field(..., description="Indicates the nature of the step")
@@ -47,7 +45,7 @@ class Plan(BaseModel):
"title": "AI Market Research Plan",
"steps": [
{
"need_web_search": True,
"need_search": True,
"title": "Current AI Market Analysis",
"description": (
"Collect data on market size, growth rates, major players, and investment trends in AI sector."

View File

@@ -11,6 +11,9 @@ You are dedicated to conducting thorough investigations using search tools and p
You have access to two types of tools:
1. **Built-in Tools**: These are always available:
{% if resources %}
- **local_search_tool**: For retrieving information from the local knowledge base when the user mentions resource files in the messages.
{% endif %}
- **web_search_tool**: For performing web searches
- **crawl_tool**: For reading content from URLs
@@ -34,7 +37,7 @@ You have access to two types of tools:
3. **Plan the Solution**: Determine the best approach to solve the problem using the available tools.
4. **Execute the Solution**:
- Forget your previous knowledge, so you **should leverage the tools** to retrieve the information.
- Use the **web_search_tool** or other suitable search tool to perform a search with the provided keywords.
- Use the {% if resources %}**local_search_tool** or {% endif %}**web_search_tool** or other suitable search tool to perform a search with the provided keywords.
- When the task includes time range requirements:
- Incorporate appropriate time-based search parameters in your queries (e.g., "after:2020", "before:2023", or specific date ranges)
- Ensure search results respect the specified time constraints.

5
src/rag/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Public API of the RAG package."""

from .retriever import Retriever, Document, Resource
from .ragflow import RAGFlowProvider
from .builder import build_retriever

# __all__ entries must be strings naming the exported symbols; listing the
# objects themselves raises TypeError on `from src.rag import *` and confuses
# static tooling.
__all__ = [
    "Retriever",
    "Document",
    "Resource",
    "RAGFlowProvider",
    "build_retriever",
]

11
src/rag/builder.py Normal file
View File

@@ -0,0 +1,11 @@
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.retriever import Retriever
def build_retriever() -> Retriever | None:
    """Instantiate the retriever for the configured RAG provider.

    Returns None when no provider is configured, a RAGFlowProvider for
    "ragflow", and raises ValueError for any other configured value.
    """
    if not SELECTED_RAG_PROVIDER:
        return None
    if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
        return RAGFlowProvider()
    raise ValueError(f"Unsupported RAG provider: {SELECTED_RAG_PROVIDER}")

130
src/rag/ragflow.py Normal file
View File

@@ -0,0 +1,130 @@
import os
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
from urllib.parse import urlparse
class RAGFlowProvider(Retriever):
"""
RAGFlowProvider is a provider that uses RAGFlow to retrieve documents.
"""
api_url: str
api_key: str
page_size: int = 10
def __init__(self):
api_url = os.getenv("RAGFLOW_API_URL")
if not api_url:
raise ValueError("RAGFLOW_API_URL is not set")
self.api_url = api_url
api_key = os.getenv("RAGFLOW_API_KEY")
if not api_key:
raise ValueError("RAGFLOW_API_KEY is not set")
self.api_key = api_key
page_size = os.getenv("RAGFLOW_PAGE_SIZE")
if page_size:
self.page_size = int(page_size)
def query_relevant_documents(
self, query: str, resources: list[Resource] = []
) -> list[Document]:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
dataset_ids: list[str] = []
document_ids: list[str] = []
for resource in resources:
dataset_id, document_id = parse_uri(resource.uri)
dataset_ids.append(dataset_id)
if document_id:
document_ids.append(document_id)
payload = {
"question": query,
"dataset_ids": dataset_ids,
"document_ids": document_ids,
"page_size": self.page_size,
}
response = requests.post(
f"{self.api_url}/api/v1/retrieval", headers=headers, json=payload
)
if response.status_code != 200:
raise Exception(f"Failed to query documents: {response.text}")
result = response.json()
data = result.get("data", {})
doc_aggs = data.get("doc_aggs", [])
docs: dict[str, Document] = {
doc.get("doc_id"): Document(
id=doc.get("doc_id"),
title=doc.get("doc_name"),
chunks=[],
)
for doc in doc_aggs
}
for chunk in data.get("chunks", []):
doc = docs.get(chunk.get("document_id"))
if doc:
doc.chunks.append(
Chunk(
content=chunk.get("content"),
similarity=chunk.get("similarity"),
)
)
return list(docs.values())
def list_resources(self, query: str | None = None) -> list[Resource]:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
params = {}
if query:
params["name"] = query
response = requests.get(
f"{self.api_url}/api/v1/datasets", headers=headers, params=params
)
if response.status_code != 200:
raise Exception(f"Failed to list resources: {response.text}")
result = response.json()
resources = []
for item in result.get("data", []):
item = Resource(
uri=f"rag://dataset/{item.get('id')}",
title=item.get("name", ""),
description=item.get("description", ""),
)
resources.append(item)
return resources
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://dataset/<dataset_id>#<document_id>`` URI.

    Returns (dataset_id, document_id); document_id is "" when the URI has
    no fragment. Raises ValueError for a non-rag scheme or a URI missing
    the dataset id (the original raised an opaque IndexError for the
    latter).
    """
    parsed = urlparse(uri)
    # For rag://dataset/<id> urlparse yields netloc="dataset", path="/<id>".
    parts = parsed.path.split("/")
    if parsed.scheme != "rag" or len(parts) < 2 or not parts[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return parts[1], parsed.fragment
if __name__ == "__main__":
    # Manual check of how urlparse decomposes a rag:// URI.
    example = "rag://dataset/123#abc"
    result = urlparse(example)
    for part in (result.scheme, result.netloc, result.path, result.fragment):
        print(part)

77
src/rag/retriever.py Normal file
View File

@@ -0,0 +1,77 @@
import abc
from pydantic import BaseModel, Field
class Chunk:
    """A single retrieved text fragment and its relevance score."""

    content: str
    similarity: float

    def __init__(self, content: str, similarity: float):
        """Store the chunk text and its similarity to the query."""
        self.content = content
        self.similarity = similarity
class Document:
"""
Document is a class that represents a document.
"""
id: str
url: str | None = None
title: str | None = None
chunks: list[Chunk] = []
def __init__(
self,
id: str,
url: str | None = None,
title: str | None = None,
chunks: list[Chunk] = [],
):
self.id = id
self.url = url
self.title = title
self.chunks = chunks
def to_dict(self) -> dict:
d = {
"id": self.id,
"content": "\n\n".join([chunk.content for chunk in self.chunks]),
}
if self.url:
d["url"] = self.url
if self.title:
d["title"] = self.title
return d
class Resource(BaseModel):
    """
    Resource is a class that represents a resource.

    A resource points at a retrievable knowledge-base item; the uri is
    expected to look like "rag://dataset/<dataset_id>" with an optional
    "#<document_id>" fragment (see parse_uri in the ragflow provider).
    """

    uri: str = Field(..., description="The URI of the resource")
    title: str = Field(..., description="The title of the resource")
    description: str | None = Field("", description="The description of the resource")
class Retriever(abc.ABC):
    """
    Define a RAG provider, which can be used to query documents and resources.

    Concrete implementations (e.g. RAGFlowProvider) translate these calls
    into requests against their backing knowledge-base service.
    """

    @abc.abstractmethod
    def list_resources(self, query: str | None = None) -> list[Resource]:
        """
        List resources from the rag provider.

        *query* is an optional name filter; implementations may ignore it.
        """
        pass

    @abc.abstractmethod
    def query_relevant_documents(
        self, query: str, resources: list[Resource] = []
    ) -> list[Document]:
        """
        Query relevant documents from the resources.

        Returns documents with their matching chunks attached.
        """
        pass

View File

@@ -5,19 +5,22 @@ import base64
import json
import logging
import os
from typing import List, cast
from typing import Annotated, List, cast
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage, BaseMessage
from langgraph.types import Command
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.rag.builder import build_retriever
from src.rag.retriever import Resource
from src.server.chat_request import (
ChatMessage,
ChatRequest,
@@ -28,6 +31,11 @@ from src.server.chat_request import (
)
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools
from src.server.rag_request import (
RAGConfigResponse,
RAGResourceRequest,
RAGResourcesResponse,
)
from src.tools import VolcengineTTS
logger = logging.getLogger(__name__)
@@ -59,6 +67,7 @@ async def chat_stream(request: ChatRequest):
_astream_workflow_generator(
request.model_dump()["messages"],
thread_id,
request.resources,
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
@@ -74,6 +83,7 @@ async def chat_stream(request: ChatRequest):
async def _astream_workflow_generator(
messages: List[ChatMessage],
thread_id: str,
resources: List[Resource],
max_plan_iterations: int,
max_step_num: int,
max_search_results: int,
@@ -101,6 +111,7 @@ async def _astream_workflow_generator(
input_,
config={
"thread_id": thread_id,
"resources": resources,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
@@ -319,3 +330,18 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest):
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
    """Get the config of the RAG."""
    # provider is None when RAG_PROVIDER is unset, i.e. RAG is disabled.
    return RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
    """Get the resources of the RAG."""
    retriever = build_retriever()
    # No configured provider means there are no resources to list.
    if retriever is None:
        return RAGResourcesResponse(resources=[])
    return RAGResourcesResponse(resources=retriever.list_resources(request.query))

View File

@@ -5,6 +5,8 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field
from src.rag.retriever import Resource
class ContentItem(BaseModel):
type: str = Field(..., description="The type of content (text, image, etc.)")
@@ -28,6 +30,9 @@ class ChatRequest(BaseModel):
messages: Optional[List[ChatMessage]] = Field(
[], description="History of messages between the user and the assistant"
)
resources: Optional[List[Resource]] = Field(
[], description="Resources to be used for the research"
)
debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
thread_id: Optional[str] = Field(
"__default__", description="A specific conversation identifier"

25
src/server/rag_request.py Normal file
View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from src.rag.retriever import Resource
class RAGConfigResponse(BaseModel):
    """Response model for RAG config."""

    # None signals that no RAG provider is configured on the server.
    provider: str | None = Field(
        None, description="The provider of the RAG, default is ragflow"
    )
class RAGResourceRequest(BaseModel):
    """Request model for RAG resource."""

    # Optional filter string forwarded to the provider's resource listing.
    query: str | None = Field(
        None, description="The query of the resource need to be searched"
    )
class RAGResourcesResponse(BaseModel):
    """Response model for RAG resources."""

    # The list of resources available from the configured provider.
    resources: list[Resource] = Field(..., description="The resources of the RAG")

View File

@@ -5,6 +5,7 @@ import os
from .crawl import crawl_tool
from .python_repl import python_repl_tool
from .retriever import get_retriever_tool
from .search import get_web_search_tool
from .tts import VolcengineTTS
@@ -12,5 +13,6 @@ __all__ = [
"crawl_tool",
"python_repl_tool",
"get_web_search_tool",
"get_retriever_tool",
"VolcengineTTS",
]

74
src/tools/retriever.py Normal file
View File

@@ -0,0 +1,74 @@
import logging
from typing import List, Optional, Type
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from pydantic import BaseModel, Field
from src.config.tools import SELECTED_RAG_PROVIDER
from src.rag import Document, Retriever, Resource, build_retriever
logger = logging.getLogger(__name__)
class RetrieverInput(BaseModel):
    """Argument schema for the local_search_tool."""

    keywords: str = Field(description="search keywords to look up")
class RetrieverTool(BaseTool):
    """LangChain tool that searches the configured local RAG knowledge base."""

    name: str = "local_search_tool"
    description: str = (
        "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # Fix: the original declared `Field(default_factory=Retriever)`, but
    # Retriever is abstract — instantiating the default raises TypeError.
    # There is no sensible default, so the field is simply required (every
    # call site already passes a concrete retriever).
    retriever: Retriever
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Query the retriever; return chunk dicts or a not-found message."""
        logger.info(
            f"Retriever tool query: {keywords}", extra={"resources": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            return "No results found from the local knowledge base."
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list[dict] | str:
        """Async entry point; delegates to the synchronous implementation."""
        # Fix: the original called run_manager.get_sync() unconditionally,
        # which raised AttributeError whenever no callback manager was
        # supplied (run_manager is Optional).
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        return self._run(keywords, sync_manager)
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Build a local-search tool for *resources*.

    Returns None when no resources were given or no RAG provider is
    configured.
    """
    if not resources:
        return None
    logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
    retriever = build_retriever()
    if retriever is None:
        return None
    return RetrieverTool(retriever=retriever, resources=resources)
if __name__ == "__main__":
    # Manual smoke test against a live RAG provider (requires env config).
    demo_resources = [
        Resource(
            uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
            title="西游记",
            description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
        )
    ]
    tool = get_retriever_tool(demo_resources)
    print(tool.name)
    print(tool.description)
    print(tool.args)
    print(tool.invoke("三打白骨精"))

View File

@@ -61,5 +61,9 @@ def get_web_search_tool(max_search_results: int):
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=3, output_format="list"
).invoke("cute panda")
print(json.dumps(results, indent=2, ensure_ascii=False))
)
print(results.name)
print(results.description)
print(results.args)
# .invoke("cute panda")
# print(json.dumps(results, indent=2, ensure_ascii=False))