mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-05-02 10:10:44 +08:00
feat: RAG Integration (#238)
* feat: add rag provider and retriever
* feat: retriever tool
* feat: add retriever tool to the researcher node
* feat: add rag http apis
* feat: new message input supports resource mentions
* feat: new message input component support resource mentions
* refactor: need_web_search to need_search
* chore: RAG integration docs
* chore: change example api host
* fix: user message color in dark mode
* fix: mentions style
* feat: add local_search_tool to researcher prompt
* chore: research prompt
* fix: ragflow page size and reporter with
* docs: ragflow integration and add acknowledgment projects
* chore: format
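For reference, a minimal sketch (not part of this commit) of how the new RAGFlow provider is wired together through environment variables; the URL, key, and query below are placeholders:

```python
# Hedged sketch: RAG_PROVIDER must be set before src.config.tools is imported,
# because SELECTED_RAG_PROVIDER is read from the environment at import time.
import os

os.environ["RAG_PROVIDER"] = "ragflow"
os.environ["RAGFLOW_API_URL"] = "http://localhost:9388"  # placeholder RAGFlow endpoint
os.environ["RAGFLOW_API_KEY"] = "ragflow-xxxx"           # placeholder API key
os.environ["RAGFLOW_PAGE_SIZE"] = "10"                   # optional, defaults to 10

from src.rag import build_retriever

retriever = build_retriever()  # RAGFlowProvider, or None when RAG_PROVIDER is unset
if retriever:
    for resource in retriever.list_resources(query="dataset"):
        print(resource.uri, resource.title)
```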
@@ -2,16 +2,21 @@
# SPDX-License-Identifier: MIT

import os
from dataclasses import dataclass, fields
from dataclasses import dataclass, field, fields
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig

from src.rag.retriever import Resource


@dataclass(kw_only=True)
class Configuration:
    """The configurable fields."""

    resources: list[Resource] = field(
        default_factory=list
    )  # Resources to be used for the research
    max_plan_iterations: int = 1  # Maximum number of plan iterations
    max_step_num: int = 3  # Maximum number of steps in a plan
    max_search_results: int = 3  # Maximum number of search results
@@ -17,3 +17,10 @@ class SearchEngine(enum.Enum):
# Tool configuration
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)


class RAGProvider(enum.Enum):
    RAGFLOW = "ragflow"


SELECTED_RAG_PROVIDER = os.getenv("RAG_PROVIDER")
@@ -17,6 +17,7 @@ from src.tools.search import LoggedTavilySearch
from src.tools import (
    crawl_tool,
    get_web_search_tool,
    get_retriever_tool,
    python_repl_tool,
)

@@ -206,10 +207,11 @@ def human_feedback_node(


def coordinator_node(
    state: State,
    state: State, config: RunnableConfig
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
    """Coordinator node that communicate with customers."""
    logger.info("Coordinator talking.")
    configurable = Configuration.from_runnable_config(config)
    messages = apply_prompt_template("coordinator", state)
    response = (
        get_llm_by_type(AGENT_LLM_MAP["coordinator"])
@@ -242,7 +244,7 @@ def coordinator_node(
    logger.debug(f"Coordinator response: {response}")

    return Command(
        update={"locale": locale},
        update={"locale": locale, "resources": configurable.resources},
        goto=goto,
    )

@@ -326,14 +328,14 @@ async def _execute_agent_step(
        logger.warning("No unexecuted step found")
        return Command(goto="research_team")

    logger.info(f"Executing step: {current_step.title}")
    logger.info(f"Executing step: {current_step.title}, agent: {agent_name}")

    # Format completed steps information
    completed_steps_info = ""
    if completed_steps:
        completed_steps_info = "# Existing Research Findings\n\n"
        for i, step in enumerate(completed_steps):
            completed_steps_info += f"## Existing Finding {i+1}: {step.title}\n\n"
            completed_steps_info += f"## Existing Finding {i + 1}: {step.title}\n\n"
            completed_steps_info += f"<finding>\n{step.execution_res}\n</finding>\n\n"

    # Prepare the input for the agent with completed steps info
@@ -347,6 +349,19 @@ async def _execute_agent_step(

    # Add citation reminder for researcher agent
    if agent_name == "researcher":
        if state.get("resources"):
            resources_info = "**The user mentioned the following resource files:**\n\n"
            for resource in state.get("resources"):
                resources_info += f"- {resource.title} ({resource.description})\n"

            agent_input["messages"].append(
                HumanMessage(
                    content=resources_info
                    + "\n\n"
                    + "You MUST use the **local_search_tool** to retrieve the information from the resource files.",
                )
            )

        agent_input["messages"].append(
            HumanMessage(
                content="IMPORTANT: DO NOT include inline citations in the text. Instead, track all sources and include a References section at the end using link reference format. Include an empty line between each citation for better readability. Use this format for each reference:\n- [Source Title](URL)\n\n- [Another Source](URL)",
@@ -377,6 +392,7 @@ async def _execute_agent_step(
        )
        recursion_limit = default_recursion_limit

    logger.info(f"Agent input: {agent_input}")
    result = await agent.ainvoke(
        input=agent_input, config={"recursion_limit": recursion_limit}
    )
@@ -468,11 +484,16 @@ async def researcher_node(
    """Researcher node that do research"""
    logger.info("Researcher node is researching.")
    configurable = Configuration.from_runnable_config(config)
    tools = [get_web_search_tool(configurable.max_search_results), crawl_tool]
    retriever_tool = get_retriever_tool(state.get("resources", []))
    if retriever_tool:
        tools.insert(0, retriever_tool)
    logger.info(f"Researcher tools: {tools}")
    return await _setup_and_execute_agent_step(
        state,
        config,
        "researcher",
        [get_web_search_tool(configurable.max_search_results), crawl_tool],
        tools,
    )
@@ -1,12 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import operator
from typing import Annotated

from langgraph.graph import MessagesState

from src.prompts.planner_model import Plan
from src.rag import Resource


class State(MessagesState):
@@ -15,6 +13,7 @@ class State(MessagesState):
    # Runtime Variables
    locale: str = "en-US"
    observations: list[str] = []
    resources: list[Resource] = []
    plan_iterations: int = 0
    current_plan: Plan | str = None
    final_report: str = ""
@@ -57,14 +57,15 @@ Before creating a detailed plan, assess if there is sufficient context to answer
Different types of steps have different web search requirements:

1. **Research Steps** (`need_web_search: true`):
1. **Research Steps** (`need_search: true`):
   - Retrieve information from the file with the URL with `rag://` or `http://` prefix specified by the user
   - Gathering market data or industry trends
   - Finding historical information
   - Collecting competitor analysis
   - Researching current events or news
   - Finding statistical data or reports

2. **Data Processing Steps** (`need_web_search: false`):
2. **Data Processing Steps** (`need_search: false`):
   - API calls and data extraction
   - Database queries
   - Raw data collection from existing sources
@@ -74,10 +75,10 @@ Different types of steps have different web search requirements:
## Exclusions

- **No Direct Calculations in Research Steps**:
  - Research steps should only gather data and information
  - All mathematical calculations must be handled by processing steps
  - Numerical analysis must be delegated to processing steps
  - Research steps focus on information gathering only
  - Research steps should only gather data and information
  - All mathematical calculations must be handled by processing steps
  - Numerical analysis must be delegated to processing steps
  - Research steps focus on information gathering only

## Analysis Framework

@@ -135,16 +136,16 @@ When planning information gathering, consider these key aspects and ensure COMPR
- To begin with, repeat user's requirement in your own words as `thought`.
- Rigorously assess if there is sufficient context to answer the question using the strict criteria above.
- If context is sufficient:
  - Set `has_enough_context` to true
  - No need to create information gathering steps
  - Set `has_enough_context` to true
  - No need to create information gathering steps
- If context is insufficient (default assumption):
  - Break down the required information using the Analysis Framework
  - Create NO MORE THAN {{ max_step_num }} focused and comprehensive steps that cover the most essential aspects
  - Ensure each step is substantial and covers related information categories
  - Prioritize breadth and depth within the {{ max_step_num }}-step constraint
  - For each step, carefully assess if web search is needed:
    - Research and external data gathering: Set `need_web_search: true`
    - Internal data processing: Set `need_web_search: false`
  - Break down the required information using the Analysis Framework
  - Create NO MORE THAN {{ max_step_num }} focused and comprehensive steps that cover the most essential aspects
  - Ensure each step is substantial and covers related information categories
  - Prioritize breadth and depth within the {{ max_step_num }}-step constraint
  - For each step, carefully assess if web search is needed:
    - Research and external data gathering: Set `need_search: true`
    - Internal data processing: Set `need_search: false`
- Specify the exact data to be collected in step's `description`. Include a `note` if necessary.
- Prioritize depth and volume of relevant information - limited information is not acceptable.
- Use the same language as the user to generate the plan.
@@ -156,10 +157,10 @@ Directly output the raw JSON format of `Plan` without "```json". The `Plan` inte

```ts
interface Step {
  need_web_search: boolean; // Must be explicitly set for each step
  need_search: boolean; // Must be explicitly set for each step
  title: string;
  description: string; // Specify exactly what data to collect
  step_type: "research" | "processing"; // Indicates the nature of the step
  description: string; // Specify exactly what data to collect. If the user input contains a link, please retain the full Markdown format when necessary.
  step_type: "research" | "processing"; // Indicates the nature of the step
}

interface Plan {
@@ -167,7 +168,7 @@ interface Plan {
  has_enough_context: boolean;
  thought: string;
  title: string;
  steps: Step[]; // Research & Processing steps to get more context
  steps: Step[]; // Research & Processing steps to get more context
}
```

@@ -179,8 +180,8 @@ interface Plan {
- Prioritize BOTH breadth (covering essential aspects) AND depth (detailed information on each aspect)
- Never settle for minimal information - the goal is a comprehensive, detailed final report
- Limited or insufficient information will lead to an inadequate final report
- Carefully assess each step's web search requirement based on its nature:
  - Research steps (`need_web_search: true`) for gathering information
  - Processing steps (`need_web_search: false`) for calculations and data processing
- Carefully assess each step's web search or retrieve from URL requirement based on its nature:
  - Research steps (`need_search: true`) for gathering information
  - Processing steps (`need_search: false`) for calculations and data processing
- Default to gathering more information unless the strictest sufficient context criteria are met
- Always use the language specified by the locale = **{{ locale }}**.
- Always use the language specified by the locale = **{{ locale }}**.
@@ -13,9 +13,7 @@ class StepType(str, Enum):
class Step(BaseModel):
    need_web_search: bool = Field(
        ..., description="Must be explicitly set for each step"
    )
    need_search: bool = Field(..., description="Must be explicitly set for each step")
    title: str
    description: str = Field(..., description="Specify exactly what data to collect")
    step_type: StepType = Field(..., description="Indicates the nature of the step")
@@ -47,7 +45,7 @@ class Plan(BaseModel):
                "title": "AI Market Research Plan",
                "steps": [
                    {
                        "need_web_search": True,
                        "need_search": True,
                        "title": "Current AI Market Analysis",
                        "description": (
                            "Collect data on market size, growth rates, major players, and investment trends in AI sector."
@@ -11,6 +11,9 @@ You are dedicated to conducting thorough investigations using search tools and p
You have access to two types of tools:

1. **Built-in Tools**: These are always available:
{% if resources %}
   - **local_search_tool**: For retrieving information from the local knowledge base when user mentioned in the messages.
{% endif %}
   - **web_search_tool**: For performing web searches
   - **crawl_tool**: For reading content from URLs

@@ -34,7 +37,7 @@ You have access to two types of tools:
3. **Plan the Solution**: Determine the best approach to solve the problem using the available tools.
4. **Execute the Solution**:
   - Forget your previous knowledge, so you **should leverage the tools** to retrieve the information.
   - Use the **web_search_tool** or other suitable search tool to perform a search with the provided keywords.
   - Use the {% if resources %}**local_search_tool** or{% endif %}**web_search_tool** or other suitable search tool to perform a search with the provided keywords.
   - When the task includes time range requirements:
     - Incorporate appropriate time-based search parameters in your queries (e.g., "after:2020", "before:2023", or specific date ranges)
     - Ensure search results respect the specified time constraints.
src/rag/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .retriever import Retriever, Document, Resource
from .ragflow import RAGFlowProvider
from .builder import build_retriever

__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever]
src/rag/builder.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.retriever import Retriever


def build_retriever() -> Retriever | None:
    if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
        return RAGFlowProvider()
    elif SELECTED_RAG_PROVIDER:
        raise ValueError(f"Unsupported RAG provider: {SELECTED_RAG_PROVIDER}")
    return None
src/rag/ragflow.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import os
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
from urllib.parse import urlparse


class RAGFlowProvider(Retriever):
    """
    RAGFlowProvider is a provider that uses RAGFlow to retrieve documents.
    """

    api_url: str
    api_key: str
    page_size: int = 10

    def __init__(self):
        api_url = os.getenv("RAGFLOW_API_URL")
        if not api_url:
            raise ValueError("RAGFLOW_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("RAGFLOW_API_KEY")
        if not api_key:
            raise ValueError("RAGFLOW_API_KEY is not set")
        self.api_key = api_key

        page_size = os.getenv("RAGFLOW_PAGE_SIZE")
        if page_size:
            self.page_size = int(page_size)

    def query_relevant_documents(
        self, query: str, resources: list[Resource] = []
    ) -> list[Document]:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        dataset_ids: list[str] = []
        document_ids: list[str] = []

        for resource in resources:
            dataset_id, document_id = parse_uri(resource.uri)
            dataset_ids.append(dataset_id)
            if document_id:
                document_ids.append(document_id)

        payload = {
            "question": query,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
            "page_size": self.page_size,
        }

        response = requests.post(
            f"{self.api_url}/api/v1/retrieval", headers=headers, json=payload
        )

        if response.status_code != 200:
            raise Exception(f"Failed to query documents: {response.text}")

        result = response.json()
        data = result.get("data", {})
        doc_aggs = data.get("doc_aggs", [])
        docs: dict[str, Document] = {
            doc.get("doc_id"): Document(
                id=doc.get("doc_id"),
                title=doc.get("doc_name"),
                chunks=[],
            )
            for doc in doc_aggs
        }

        for chunk in data.get("chunks", []):
            doc = docs.get(chunk.get("document_id"))
            if doc:
                doc.chunks.append(
                    Chunk(
                        content=chunk.get("content"),
                        similarity=chunk.get("similarity"),
                    )
                )

        return list(docs.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        params = {}
        if query:
            params["name"] = query

        response = requests.get(
            f"{self.api_url}/api/v1/datasets", headers=headers, params=params
        )

        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        result = response.json()
        resources = []

        for item in result.get("data", []):
            item = Resource(
                uri=f"rag://dataset/{item.get('id')}",
                title=item.get("name", ""),
                description=item.get("description", ""),
            )
            resources.append(item)

        return resources


def parse_uri(uri: str) -> tuple[str, str]:
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    return parsed.path.split("/")[1], parsed.fragment


if __name__ == "__main__":
    uri = "rag://dataset/123#abc"
    parsed = urlparse(uri)
    print(parsed.scheme)
    print(parsed.netloc)
    print(parsed.path)
    print(parsed.fragment)
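As the parse_uri helper and its __main__ check above suggest, resource URIs follow the pattern rag://dataset/<dataset_id>#<document_id>, where the fragment (document id) is optional. A quick sanity check of that reading:

```python
from src.rag.ragflow import parse_uri

# The dataset id comes from the URI path; the document id from the optional fragment.
assert parse_uri("rag://dataset/123#abc") == ("123", "abc")
assert parse_uri("rag://dataset/123") == ("123", "")
```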
src/rag/retriever.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import abc
from pydantic import BaseModel, Field


class Chunk:
    content: str
    similarity: float

    def __init__(self, content: str, similarity: float):
        self.content = content
        self.similarity = similarity


class Document:
    """
    Document is a class that represents a document.
    """

    id: str
    url: str | None = None
    title: str | None = None
    chunks: list[Chunk] = []

    def __init__(
        self,
        id: str,
        url: str | None = None,
        title: str | None = None,
        chunks: list[Chunk] = [],
    ):
        self.id = id
        self.url = url
        self.title = title
        self.chunks = chunks

    def to_dict(self) -> dict:
        d = {
            "id": self.id,
            "content": "\n\n".join([chunk.content for chunk in self.chunks]),
        }
        if self.url:
            d["url"] = self.url
        if self.title:
            d["title"] = self.title
        return d


class Resource(BaseModel):
    """
    Resource is a class that represents a resource.
    """

    uri: str = Field(..., description="The URI of the resource")
    title: str = Field(..., description="The title of the resource")
    description: str | None = Field("", description="The description of the resource")


class Retriever(abc.ABC):
    """
    Define a RAG provider, which can be used to query documents and resources.
    """

    @abc.abstractmethod
    def list_resources(self, query: str | None = None) -> list[Resource]:
        """
        List resources from the rag provider.
        """
        pass

    @abc.abstractmethod
    def query_relevant_documents(
        self, query: str, resources: list[Resource] = []
    ) -> list[Document]:
        """
        Query relevant documents from the resources.
        """
        pass
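The Retriever ABC above is the extension point that build_retriever hands back. A hedged sketch of what an additional provider could look like (purely illustrative, not part of this commit; names and data are made up):

```python
from src.rag.retriever import Chunk, Document, Resource, Retriever


class InMemoryProvider(Retriever):
    """Toy provider over an in-memory {title: text} mapping, for illustration only."""

    def __init__(self, docs: dict[str, str]):
        self._docs = docs

    def list_resources(self, query: str | None = None) -> list[Resource]:
        return [
            Resource(uri=f"rag://dataset/demo#{i}", title=title, description="")
            for i, title in enumerate(self._docs)
            if query is None or query.lower() in title.lower()
        ]

    def query_relevant_documents(
        self, query: str, resources: list[Resource] = []
    ) -> list[Document]:
        return [
            Document(id=str(i), title=title, chunks=[Chunk(content=text, similarity=1.0)])
            for i, (title, text) in enumerate(self._docs.items())
            if query.lower() in text.lower()
        ]
```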
@@ -5,19 +5,22 @@ import base64
import json
import logging
import os
from typing import List, cast
from typing import Annotated, List, cast
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage, BaseMessage
from langgraph.types import Command

from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.rag.builder import build_retriever
from src.rag.retriever import Resource
from src.server.chat_request import (
    ChatMessage,
    ChatRequest,
@@ -28,6 +31,11 @@ from src.server.chat_request import (
)
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools
from src.server.rag_request import (
    RAGConfigResponse,
    RAGResourceRequest,
    RAGResourcesResponse,
)
from src.tools import VolcengineTTS

logger = logging.getLogger(__name__)
@@ -59,6 +67,7 @@ async def chat_stream(request: ChatRequest):
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.resources,
            request.max_plan_iterations,
            request.max_step_num,
            request.max_search_results,
@@ -74,6 +83,7 @@ async def chat_stream(request: ChatRequest):
async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    resources: List[Resource],
    max_plan_iterations: int,
    max_step_num: int,
    max_search_results: int,
@@ -101,6 +111,7 @@ async def _astream_workflow_generator(
        input_,
        config={
            "thread_id": thread_id,
            "resources": resources,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
            "max_search_results": max_search_results,
@@ -319,3 +330,18 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest):
        logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
        raise


@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
    """Get the config of the RAG."""
    return RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)


@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
    """Get the resources of the RAG."""
    retriever = build_retriever()
    if retriever:
        return RAGResourcesResponse(resources=retriever.list_resources(request.query))
    return RAGResourcesResponse(resources=[])
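A hedged sketch of exercising the two new endpoints from a client; the host and port are assumptions about where the FastAPI app is served, not part of this commit:

```python
import requests

base = "http://localhost:8000"  # assumed server address

# GET /api/rag/config -> RAGConfigResponse, e.g. {"provider": "ragflow"} or {"provider": null}
print(requests.get(f"{base}/api/rag/config").json())

# GET /api/rag/resources?query=... -> RAGResourcesResponse with a list of {uri, title, description}
data = requests.get(f"{base}/api/rag/resources", params={"query": "dataset"}).json()
for res in data["resources"]:
    print(res["uri"], res["title"])
```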
@@ -5,6 +5,8 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field

from src.rag.retriever import Resource


class ContentItem(BaseModel):
    type: str = Field(..., description="The type of content (text, image, etc.)")
@@ -28,6 +30,9 @@ class ChatRequest(BaseModel):
    messages: Optional[List[ChatMessage]] = Field(
        [], description="History of messages between the user and the assistant"
    )
    resources: Optional[List[Resource]] = Field(
        [], description="Resources to be used for the research"
    )
    debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
    thread_id: Optional[str] = Field(
        "__default__", description="A specific conversation identifier"
src/server/rag_request.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field

from src.rag.retriever import Resource


class RAGConfigResponse(BaseModel):
    """Response model for RAG config."""

    provider: str | None = Field(
        None, description="The provider of the RAG, default is ragflow"
    )


class RAGResourceRequest(BaseModel):
    """Request model for RAG resource."""

    query: str | None = Field(
        None, description="The query of the resource need to be searched"
    )


class RAGResourcesResponse(BaseModel):
    """Response model for RAG resources."""

    resources: list[Resource] = Field(..., description="The resources of the RAG")
@@ -5,6 +5,7 @@ import os
from .crawl import crawl_tool
from .python_repl import python_repl_tool
from .retriever import get_retriever_tool
from .search import get_web_search_tool
from .tts import VolcengineTTS

@@ -12,5 +13,6 @@ __all__ = [
    "crawl_tool",
    "python_repl_tool",
    "get_web_search_tool",
    "get_retriever_tool",
    "VolcengineTTS",
]
src/tools/retriever.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import logging
from typing import List, Optional, Type
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from pydantic import BaseModel, Field

from src.config.tools import SELECTED_RAG_PROVIDER
from src.rag import Document, Retriever, Resource, build_retriever

logger = logging.getLogger(__name__)


class RetrieverInput(BaseModel):
    keywords: str = Field(description="search keywords to look up")


class RetrieverTool(BaseTool):
    name: str = "local_search_tool"
    description: str = (
        "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
    )
    args_schema: Type[BaseModel] = RetrieverInput

    retriever: Retriever = Field(default_factory=Retriever)
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list[Document]:
        logger.info(
            f"Retriever tool query: {keywords}", extra={"resources": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            return "No results found from the local knowledge base."
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list[Document]:
        return self._run(keywords, run_manager.get_sync())


def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    if not resources:
        return None
    logger.info(f"create retriever tool: {SELECTED_RAG_PROVIDER}")
    retriever = build_retriever()

    if not retriever:
        return None
    return RetrieverTool(retriever=retriever, resources=resources)


if __name__ == "__main__":
    resources = [
        Resource(
            uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
            title="西游记",
            description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
        )
    ]
    retriever_tool = get_retriever_tool(resources)
    print(retriever_tool.name)
    print(retriever_tool.description)
    print(retriever_tool.args)
    print(retriever_tool.invoke("三打白骨精"))
@@ -61,5 +61,9 @@ def get_web_search_tool(max_search_results: int):
if __name__ == "__main__":
    results = LoggedDuckDuckGoSearch(
        name="web_search", max_results=3, output_format="list"
    ).invoke("cute panda")
    print(json.dumps(results, indent=2, ensure_ascii=False))
    )
    print(results.name)
    print(results.description)
    print(results.args)
    # .invoke("cute panda")
    # print(json.dumps(results, indent=2, ensure_ascii=False))