feat: RAG Integration (#238)

* feat: add rag provider and retriever

* feat: retriever tool

* feat: add retriever tool to the researcher node

* feat: add rag http apis

* feat: new message input supports resource mentions

* feat: new message input component support resource mentions

* refactor: need_web_search to need_search

* chore: RAG integration docs

* chore: change example api host

* fix: user message color in dark mode

* fix: mentions style

* feat: add local_search_tool to researcher prompt

* chore: research prompt

* fix: ragflow page size and reporter with

* docs: ragflow integration and add acknowledgment projects

* chore: format
This commit is contained in:
JeffJiang
2025-05-28 14:13:46 +08:00
committed by GitHub
parent 0565ab6d27
commit 462752b462
43 changed files with 1172 additions and 181 deletions

View File

@@ -5,19 +5,22 @@ import base64
import json
import logging
import os
from typing import List, cast
from typing import Annotated, List, cast
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage, BaseMessage
from langgraph.types import Command
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.rag.builder import build_retriever
from src.rag.retriever import Resource
from src.server.chat_request import (
ChatMessage,
ChatRequest,
@@ -28,6 +31,11 @@ from src.server.chat_request import (
)
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools
from src.server.rag_request import (
RAGConfigResponse,
RAGResourceRequest,
RAGResourcesResponse,
)
from src.tools import VolcengineTTS
logger = logging.getLogger(__name__)
@@ -59,6 +67,7 @@ async def chat_stream(request: ChatRequest):
_astream_workflow_generator(
request.model_dump()["messages"],
thread_id,
request.resources,
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
@@ -74,6 +83,7 @@ async def chat_stream(request: ChatRequest):
async def _astream_workflow_generator(
messages: List[ChatMessage],
thread_id: str,
resources: List[Resource],
max_plan_iterations: int,
max_step_num: int,
max_search_results: int,
@@ -101,6 +111,7 @@ async def _astream_workflow_generator(
input_,
config={
"thread_id": thread_id,
"resources": resources,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
@@ -319,3 +330,18 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest):
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
"""Get the config of the RAG."""
return RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
"""Get the resources of the RAG."""
retriever = build_retriever()
if retriever:
return RAGResourcesResponse(resources=retriever.list_resources(request.query))
return RAGResourcesResponse(resources=[])

View File

@@ -5,6 +5,8 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field
from src.rag.retriever import Resource
class ContentItem(BaseModel):
type: str = Field(..., description="The type of content (text, image, etc.)")
@@ -28,6 +30,9 @@ class ChatRequest(BaseModel):
messages: Optional[List[ChatMessage]] = Field(
[], description="History of messages between the user and the assistant"
)
resources: Optional[List[Resource]] = Field(
[], description="Resources to be used for the research"
)
debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
thread_id: Optional[str] = Field(
"__default__", description="A specific conversation identifier"

25
src/server/rag_request.py Normal file
View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from src.rag.retriever import Resource
class RAGConfigResponse(BaseModel):
"""Response model for RAG config."""
provider: str | None = Field(
None, description="The provider of the RAG, default is ragflow"
)
class RAGResourceRequest(BaseModel):
"""Request model for RAG resource."""
query: str | None = Field(
None, description="The query of the resource need to be searched"
)
class RAGResourcesResponse(BaseModel):
"""Response model for RAG resources."""
resources: list[Resource] = Field(..., description="The resources of the RAG")