Mirror of https://gitee.com/wanwujie/deer-flow (synced 2026-04-25 15:04:46 +08:00)
feat: Enhance chat streaming and tool call processing (#498)
* feat: Enhance chat streaming and tool call processing
  - Added support for a MongoDB checkpointer in the chat streaming workflow.
  - Introduced functions to process tool call chunks and sanitize their arguments.
  - Improved event message creation with additional metadata.
  - Enhanced error handling for JSON serialization in event messages.
  - Updated the frontend to convert escaped characters in tool call arguments.
  - Refactored workflow input preparation and initial message processing.
  - Added new dependencies for MongoDB integration and tool argument sanitization.
* fix: Update the MongoDB checkpointer configuration to use LANGGRAPH_CHECKPOINT_DB_URL
* feat: Add support for Postgres checkpointing and update the README with database recommendations
* feat: Implement checkpoint saver functionality and update MongoDB connection handling
* refactor: Improve code formatting and readability in app.py and json_utils.py
* refactor: Clean up commented code and improve formatting in server.py
* refactor: Remove unused imports and unnecessary comments, and improve code organization in app.py
* chore: Pin langgraph-checkpoint-postgres==2.0.21 to avoid the JSON conversion issue in the latest version; implement chat stream persistence with Postgres
* feat: Add MongoDB and PostgreSQL support for LangGraph checkpointing; enhance environment variable handling
* fix: Update comments for clarity on the Windows event loop policy
* chore: Remove empty code changes in the MongoDB and PostgreSQL checkpoint tests and in test_checkpoint.py; clean up unused imports in checkpoint-related files
* test: Update status code assertions in MCP endpoint tests to allow for 403 responses
* test: Update MCP endpoint tests to assert specific status codes and enable MCP server configuration
* chore: Remove unnecessary environment variables from the unittest workflow
* fix: Invert the condition of the MCP server configuration check to raise 403 when disabled
* chore: Remove pymongo from the test dependencies in uv.lock
* chore: Optimize the _get_agent_name method
* test: Add unit and persistence tests for ChatStreamManager initialization with PostgreSQL and MongoDB, verifying message aggregation

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
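For orientation before the diff: the /api/chat/stream endpoint touched below serves server-sent events. A minimal client sketch, assuming only the request fields visible in the diff (thread_id and messages) and a locally running server on port 8000, which this page does not specify:

    import asyncio
    import httpx

    async def consume_chat_stream():
        # Assumed request shape; ChatRequest's full schema is not part of this diff.
        payload = {
            "thread_id": "__default__",  # the server replaces this sentinel with a fresh uuid4
            "messages": [{"role": "user", "content": "hi"}],
        }
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST", "http://localhost:8000/api/chat/stream", json=payload
            ) as resp:
                async for line in resp.aiter_lines():
                    print(line)  # SSE frames: "event: ..." then "data: {...}"

    asyncio.run(consume_chat_stream())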
@@ -4,7 +4,6 @@
 import base64
 import json
 import logging
-import os
 from typing import Annotated, List, cast
 from uuid import uuid4
 
@@ -13,8 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response, StreamingResponse
 from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
 from langgraph.types import Command
+from langgraph.store.memory import InMemoryStore
+from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
+from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+from psycopg_pool import AsyncConnectionPool
 
-from src.config.configuration import get_recursion_limit
+from src.config.configuration import get_recursion_limit, get_bool_env, get_str_env
 from src.config.report_style import ReportStyle
 from src.config.tools import SELECTED_RAG_PROVIDER
 from src.graph.builder import build_graph_with_memory
@@ -42,6 +45,8 @@ from src.server.rag_request import (
     RAGResourcesResponse,
 )
 from src.tools import VolcengineTTS
+from src.graph.checkpoint import chat_stream_message
+from src.utils.json_utils import sanitize_args
 
 logger = logging.getLogger(__name__)
 
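sanitize_args is imported from src/utils/json_utils, whose implementation is not part of this diff. Judging from the commit notes about converting escaped characters in tool call arguments, a hypothetical sketch of its behavior (not the actual implementation) might look like:

    import json

    def sanitize_args(args: str) -> str:
        # Hypothetical sketch: collapse doubly-escaped JSON argument
        # strings streamed by the model into plain JSON text.
        if not isinstance(args, str) or not args:
            return args
        try:
            parsed = json.loads(args)
            if isinstance(parsed, str):
                # Double-encoded payload: one decode yields the intended JSON text.
                return parsed
            return json.dumps(parsed, ensure_ascii=False)
        except (TypeError, ValueError):
            # Partial fragments arrive mid-stream; pass them through untouched.
            return args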
@@ -56,7 +61,7 @@ app = FastAPI(
 # Add CORS middleware
 # It's recommended to load the allowed origins from an environment variable
 # for better security and flexibility across different environments.
-allowed_origins_str = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000")
+allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
 allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]
 
 logger.info(f"Allowed origins: {allowed_origins}")
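The comma-separated parsing above means several origins can be supplied in one variable; an illustrative value (not from this diff):

    # e.g. ALLOWED_ORIGINS="http://localhost:3000, https://deerflow.example.com"
    allowed_origins_str = "http://localhost:3000, https://deerflow.example.com"
    allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]
    # -> ["http://localhost:3000", "https://deerflow.example.com"]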
@@ -68,18 +73,14 @@ app.add_middleware(
     allow_methods=["GET", "POST", "OPTIONS"],  # Use the configured list of methods
     allow_headers=["*"],  # Now allow all headers, but can be restricted further
 )
 
+in_memory_store = InMemoryStore()
 graph = build_graph_with_memory()
 
 
 @app.post("/api/chat/stream")
 async def chat_stream(request: ChatRequest):
     # Check if MCP server configuration is enabled
-    mcp_enabled = os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() in [
-        "true",
-        "1",
-        "yes",
-    ]
+    mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)
 
     # Validate MCP settings if provided
     if request.mcp_settings and not mcp_enabled:
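get_bool_env comes from src.config.configuration and is not shown in this diff; presumably it generalizes the inline check it replaces. A sketch consistent with the removed code, offered as an assumption rather than the actual implementation:

    import os

    def get_bool_env(name: str, default: bool = False) -> bool:
        # Sketch only; mirrors the removed inline check, where "true", "1",
        # and "yes" (case-insensitive) count as enabled.
        value = os.getenv(name)
        if value is None or value.strip() == "":
            return default
        return value.strip().lower() in ("true", "1", "yes")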
@@ -91,6 +92,7 @@ async def chat_stream(request: ChatRequest):
     thread_id = request.thread_id
     if thread_id == "__default__":
         thread_id = str(uuid4())
+
     return StreamingResponse(
         _astream_workflow_generator(
             request.model_dump()["messages"],
@@ -110,6 +112,154 @@ async def chat_stream(request: ChatRequest):
     )
 
 
+def _process_tool_call_chunks(tool_call_chunks):
+    """Process tool call chunks and sanitize arguments."""
+    chunks = []
+    for chunk in tool_call_chunks:
+        chunks.append(
+            {
+                "name": chunk.get("name", ""),
+                "args": sanitize_args(chunk.get("args", "")),
+                "id": chunk.get("id", ""),
+                "index": chunk.get("index", 0),
+                "type": chunk.get("type", ""),
+            }
+        )
+    return chunks
+
+
+def _get_agent_name(agent, message_metadata):
+    """Extract agent name from agent tuple."""
+    agent_name = "unknown"
+    if agent and len(agent) > 0:
+        agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
+    else:
+        agent_name = message_metadata.get("langgraph_node", "unknown")
+    return agent_name
+
+
+def _create_event_stream_message(
+    message_chunk, message_metadata, thread_id, agent_name
+):
+    """Create base event stream message."""
+    event_stream_message = {
+        "thread_id": thread_id,
+        "agent": agent_name,
+        "id": message_chunk.id,
+        "role": "assistant",
+        "checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
+        "langgraph_node": message_metadata.get("langgraph_node", ""),
+        "langgraph_path": message_metadata.get("langgraph_path", ""),
+        "langgraph_step": message_metadata.get("langgraph_step", ""),
+        "content": message_chunk.content,
+    }
+
+    # Add optional fields
+    if message_chunk.additional_kwargs.get("reasoning_content"):
+        event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
+            "reasoning_content"
+        ]
+
+    if message_chunk.response_metadata.get("finish_reason"):
+        event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
+            "finish_reason"
+        )
+
+    return event_stream_message
+
+
+def _create_interrupt_event(thread_id, event_data):
+    """Create interrupt event."""
+    return _make_event(
+        "interrupt",
+        {
+            "thread_id": thread_id,
+            "id": event_data["__interrupt__"][0].ns[0],
+            "role": "assistant",
+            "content": event_data["__interrupt__"][0].value,
+            "finish_reason": "interrupt",
+            "options": [
+                {"text": "Edit plan", "value": "edit_plan"},
+                {"text": "Start research", "value": "accepted"},
+            ],
+        },
+    )
+
+
+def _process_initial_messages(message, thread_id):
+    """Process initial messages and yield formatted events."""
+    json_data = json.dumps(
+        {
+            "thread_id": thread_id,
+            "id": "run--" + message.get("id", uuid4().hex),
+            "role": "user",
+            "content": message.get("content", ""),
+        },
+        ensure_ascii=False,
+        separators=(",", ":"),
+    )
+    chat_stream_message(
+        thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none"
+    )
+
+
+async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
+    """Process a single message chunk and yield appropriate events."""
+    agent_name = _get_agent_name(agent, message_metadata)
+    event_stream_message = _create_event_stream_message(
+        message_chunk, message_metadata, thread_id, agent_name
+    )
+
+    if isinstance(message_chunk, ToolMessage):
+        # Tool Message - Return the result of the tool call
+        event_stream_message["tool_call_id"] = message_chunk.tool_call_id
+        yield _make_event("tool_call_result", event_stream_message)
+    elif isinstance(message_chunk, AIMessageChunk):
+        # AI Message - Raw message tokens
+        if message_chunk.tool_calls:
+            # AI Message - Tool Call
+            event_stream_message["tool_calls"] = message_chunk.tool_calls
+            event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
+                message_chunk.tool_call_chunks
+            )
+            yield _make_event("tool_calls", event_stream_message)
+        elif message_chunk.tool_call_chunks:
+            # AI Message - Tool Call Chunks
+            event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
+                message_chunk.tool_call_chunks
+            )
+            yield _make_event("tool_call_chunks", event_stream_message)
+        else:
+            # AI Message - Raw message tokens
+            yield _make_event("message_chunk", event_stream_message)
+
+
+async def _stream_graph_events(
+    graph_instance, workflow_input, workflow_config, thread_id
+):
+    """Stream events from the graph and process them."""
+    async for agent, _, event_data in graph_instance.astream(
+        workflow_input,
+        config=workflow_config,
+        stream_mode=["messages", "updates"],
+        subgraphs=True,
+    ):
+        if isinstance(event_data, dict):
+            if "__interrupt__" in event_data:
+                yield _create_interrupt_event(thread_id, event_data)
+            continue
+
+        message_chunk, message_metadata = cast(
+            tuple[BaseMessage, dict[str, any]], event_data
+        )
+
+        async for event in _process_message_chunk(
+            message_chunk, message_metadata, thread_id, agent
+        ):
+            yield event
+
+
 async def _astream_workflow_generator(
     messages: List[dict],
     thread_id: str,
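Every event the new helpers emit through _make_event is a standard SSE frame: an "event:" line naming the type, a "data:" line carrying the JSON payload, and a blank line. An abbreviated, illustrative message_chunk frame (field values invented; keys taken from _create_event_stream_message):

    event: message_chunk
    data: {"thread_id": "42b7...", "agent": "planner", "id": "run--9f1c...", "role": "assistant", "langgraph_node": "planner", "langgraph_step": 1, "content": "Let me draft a plan"}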
@@ -124,7 +274,13 @@ async def _astream_workflow_generator(
     report_style: ReportStyle,
     enable_deep_thinking: bool,
 ):
-    input_ = {
+    # Process initial messages
+    for message in messages:
+        if isinstance(message, dict) and "content" in message:
+            _process_initial_messages(message, thread_id)
+
+    # Prepare workflow input
+    workflow_input = {
         "messages": messages,
         "plan_iterations": 0,
         "final_report": "",
@@ -134,112 +290,105 @@ async def _astream_workflow_generator(
         "enable_background_investigation": enable_background_investigation,
         "research_topic": messages[-1]["content"] if messages else "",
     }
 
     if not auto_accepted_plan and interrupt_feedback:
         resume_msg = f"[{interrupt_feedback}]"
         # add the last message to the resume message
         if messages:
             resume_msg += f" {messages[-1]['content']}"
-        input_ = Command(resume=resume_msg)
-    async for agent, _, event_data in graph.astream(
-        input_,
-        config={
-            "thread_id": thread_id,
-            "resources": resources,
-            "max_plan_iterations": max_plan_iterations,
-            "max_step_num": max_step_num,
-            "max_search_results": max_search_results,
-            "mcp_settings": mcp_settings,
-            "report_style": report_style.value,
-            "enable_deep_thinking": enable_deep_thinking,
-            "recursion_limit": get_recursion_limit(),
-        },
-        stream_mode=["messages", "updates"],
-        subgraphs=True,
-    ):
-        if isinstance(event_data, dict):
-            if "__interrupt__" in event_data:
-                yield _make_event(
-                    "interrupt",
-                    {
-                        "thread_id": thread_id,
-                        "id": event_data["__interrupt__"][0].ns[0],
-                        "role": "assistant",
-                        "content": event_data["__interrupt__"][0].value,
-                        "finish_reason": "interrupt",
-                        "options": [
-                            {"text": "Edit plan", "value": "edit_plan"},
-                            {"text": "Start research", "value": "accepted"},
-                        ],
-                    },
-                )
-            continue
-        message_chunk, message_metadata = cast(
-            tuple[BaseMessage, dict[str, any]], event_data
-        )
-        # Handle empty agent tuple gracefully
-        agent_name = "planner"
-        if agent and len(agent) > 0:
-            agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
-        event_stream_message: dict[str, any] = {
-            "thread_id": thread_id,
-            "agent": agent_name,
-            "id": message_chunk.id,
-            "role": "assistant",
-            "content": message_chunk.content,
-        }
-        if message_chunk.additional_kwargs.get("reasoning_content"):
-            event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
-                "reasoning_content"
-            ]
-        if message_chunk.response_metadata.get("finish_reason"):
-            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
-                "finish_reason"
-            )
-        if isinstance(message_chunk, ToolMessage):
-            # Tool Message - Return the result of the tool call
-            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
-            yield _make_event("tool_call_result", event_stream_message)
-        elif isinstance(message_chunk, AIMessageChunk):
-            # AI Message - Raw message tokens
-            if message_chunk.tool_calls:
-                # AI Message - Tool Call
-                event_stream_message["tool_calls"] = message_chunk.tool_calls
-                event_stream_message["tool_call_chunks"] = (
-                    message_chunk.tool_call_chunks
-                )
-                yield _make_event("tool_calls", event_stream_message)
-            elif message_chunk.tool_call_chunks:
-                # AI Message - Tool Call Chunks
-                event_stream_message["tool_call_chunks"] = (
-                    message_chunk.tool_call_chunks
-                )
-                yield _make_event("tool_call_chunks", event_stream_message)
-            else:
-                # AI Message - Raw message tokens
-                yield _make_event("message_chunk", event_stream_message)
+        workflow_input = Command(resume=resume_msg)
+
+    # Prepare workflow config
+    workflow_config = {
+        "thread_id": thread_id,
+        "resources": resources,
+        "max_plan_iterations": max_plan_iterations,
+        "max_step_num": max_step_num,
+        "max_search_results": max_search_results,
+        "mcp_settings": mcp_settings,
+        "report_style": report_style.value,
+        "enable_deep_thinking": enable_deep_thinking,
+        "recursion_limit": get_recursion_limit(),
+    }
+
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
+    # Handle checkpointer if configured
+    connection_kwargs = {
+        "autocommit": True,
+        "row_factory": "dict_row",
+        "prepare_threshold": 0,
+    }
+    if checkpoint_saver and checkpoint_url != "":
+        if checkpoint_url.startswith("postgresql://"):
+            logger.info("start async postgres checkpointer.")
+            async with AsyncConnectionPool(
+                checkpoint_url, kwargs=connection_kwargs
+            ) as conn:
+                checkpointer = AsyncPostgresSaver(conn)
+                await checkpointer.setup()
+                graph.checkpointer = checkpointer
+                graph.store = in_memory_store
+                async for event in _stream_graph_events(
+                    graph, workflow_input, workflow_config, thread_id
+                ):
+                    yield event
+
+        if checkpoint_url.startswith("mongodb://"):
+            logger.info("start async mongodb checkpointer.")
+            async with AsyncMongoDBSaver.from_conn_string(
+                checkpoint_url
+            ) as checkpointer:
+                graph.checkpointer = checkpointer
+                graph.store = in_memory_store
+                async for event in _stream_graph_events(
+                    graph, workflow_input, workflow_config, thread_id
+                ):
+                    yield event
+    else:
+        # Use graph without MongoDB checkpointer
+        async for event in _stream_graph_events(
+            graph, workflow_input, workflow_config, thread_id
+        ):
+            yield event
 
 
 def _make_event(event_type: str, data: dict[str, any]):
     if data.get("content") == "":
         data.pop("content")
-    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+    # Ensure JSON serialization with proper encoding
+    try:
+        json_data = json.dumps(data, ensure_ascii=False)
+
+        finish_reason = data.get("finish_reason", "")
+        chat_stream_message(
+            data.get("thread_id", ""),
+            f"event: {event_type}\ndata: {json_data}\n\n",
+            finish_reason,
+        )
+
+        return f"event: {event_type}\ndata: {json_data}\n\n"
+    except (TypeError, ValueError) as e:
+        logger.error(f"Error serializing event data: {e}")
+        # Return a safe error event
+        error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
+        return f"event: error\ndata: {error_data}\n\n"
 
 
 @app.post("/api/tts")
 async def text_to_speech(request: TTSRequest):
     """Convert text to speech using volcengine TTS API."""
-    app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
+    app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
     if not app_id:
         raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
-    access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
+    access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
     if not access_token:
         raise HTTPException(
             status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
         )
 
     try:
-        cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
-        voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
+        cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
+        voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
 
         tts_client = VolcengineTTS(
             appid=app_id,
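The checkpointer branch above is driven entirely by the two environment variables read at the top of it, with the backend chosen by the URL scheme. An illustrative way to enable it (connection strings are placeholders, not from this diff):

    import os

    # Gate the feature, then pick the backend by URL scheme.
    os.environ["LANGGRAPH_CHECKPOINT_SAVER"] = "true"
    os.environ["LANGGRAPH_CHECKPOINT_DB_URL"] = "postgresql://user:pass@localhost:5432/deerflow"
    # ...or: os.environ["LANGGRAPH_CHECKPOINT_DB_URL"] = "mongodb://localhost:27017"
    # With the flag unset or the URL empty, streaming falls back to the default in-memory graph.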
@@ -382,11 +531,7 @@ async def enhance_prompt(request: EnhancePromptRequest):
 async def mcp_server_metadata(request: MCPServerMetadataRequest):
     """Get information about an MCP server."""
     # Check if MCP server configuration is enabled
-    if os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() not in [
-        "true",
-        "1",
-        "yes",
-    ]:
+    if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
         raise HTTPException(
             status_code=403,
             detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",