feat: Enhance chat streaming and tool call processing (#498)

* feat: Enhance chat streaming and tool call processing

- Added support for MongoDB checkpointer in the chat streaming workflow.
- Introduced functions to process tool call chunks and sanitize arguments (see the sanitizer sketch after this list).
- Improved event message creation with additional metadata.
- Enhanced error handling for JSON serialization in event messages.
- Updated the frontend to convert escaped characters in tool call arguments.
- Refactored the workflow input preparation and initial message processing.
- Added new dependencies for MongoDB integration and tool argument sanitization.
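
The sanitizer referenced above lives in src/utils/json_utils.py, whose body is not part of this
diff. A minimal sketch of what such a helper might do, assuming its job is to keep streamed
tool-call argument fragments JSON-safe (the escape handling below is an illustration, not the
shipped implementation):

    def sanitize_args(args: str) -> str:
        """Best-effort cleanup of a streamed tool-call argument fragment."""
        # Hypothetical sketch: the real sanitize_args in src/utils/json_utils.py
        # may differ. Tool-call args stream in as partial JSON strings, so the
        # goal is to normalize escape sequences without raising on bad input.
        if not isinstance(args, str):
            return ""
        # Convert doubly-escaped sequences back to their literal characters.
        return args.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"')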

* fix: Update MongoDB checkpointer configuration to use LANGGRAPH_CHECKPOINT_DB_URL

* feat: Add support for Postgres checkpointing and update README with database recommendations

* feat: Implement checkpoint saver functionality and update MongoDB connection handling

* refactor: Improve code formatting and readability in app.py and json_utils.py

* refactor: Clean up commented code and improve formatting in server.py

* refactor: Remove unused imports and improve code organization in app.py

* refactor: Improve code organization and remove unnecessary comments in app.py

* chore: use langgraph-checkpoint-postgres==2.0.21 to avoid the JSON conversion issue in the latest version; implement chat stream persistence with Postgres

* feat: add MongoDB and PostgreSQL support for LangGraph checkpointing, enhance environment variable handling
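
The checkpointing toggle reads two environment variables, LANGGRAPH_CHECKPOINT_SAVER and
LANGGRAPH_CHECKPOINT_DB_URL (see the diff below). The get_bool_env/get_str_env helpers come from
src/config/configuration.py, which is not included in this diff; a plausible sketch, assuming the
same truthy-string convention as the inline code they replace:

    import os

    # Hypothetical sketches; the shipped helpers in src/config/configuration.py
    # may differ in detail.
    def get_str_env(name: str, default: str = "") -> str:
        return os.getenv(name, default)

    def get_bool_env(name: str, default: bool = False) -> bool:
        value = os.getenv(name)
        if value is None:
            return default
        # Mirrors the truthy set used by the inline code this PR replaces.
        return value.strip().lower() in ("true", "1", "yes")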

* fix: update comments for clarity on Windows event loop policy

* chore: remove empty code changes in MongoDB and PostgreSQL checkpoint tests

* chore: clean up unused imports and code in checkpoint-related files

* chore: remove empty code changes in test_checkpoint.py

* chore: remove empty code changes in test_checkpoint.py

* chore: remove empty code changes in test_checkpoint.py

* test: update status code assertions in MCP endpoint tests to allow for 403 responses

* test: update MCP endpoint tests to assert specific status codes and enable MCP server configuration
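
A test in this spirit (names, route path, and request payload are illustrative, inferred from the
handler in this diff rather than copied from the repository's test suite) checks that the metadata
endpoint returns 403 while the feature flag is off:

    from fastapi.testclient import TestClient

    def test_mcp_metadata_disabled_returns_403(monkeypatch):
        # Assumed route and payload shape; both are inferred from the handler
        # shown in this diff, not taken from the actual test module.
        monkeypatch.setenv("ENABLE_MCP_SERVER_CONFIGURATION", "false")
        from src.server.app import app

        client = TestClient(app)
        response = client.post(
            "/api/mcp/server/metadata",
            json={"transport": "stdio", "command": "echo"},
        )
        assert response.status_code == 403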

* chore: remove unnecessary environment variables from unittest workflow

* fix: invert condition for MCP server configuration check to raise 403 when disabled

* chore: remove pymongo from test dependencies in uv.lock

* chore: optimize the _get_agent_name method

* test: enhance ChatStreamManager tests for PostgreSQL and MongoDB initialization

* test: add persistence tests for ChatStreamManager with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB

* test: enhance persistence tests for ChatStreamManager with PostgreSQL and MongoDB to verify message aggregation

* test: add unit tests for ChatStreamManager with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB
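
For reference, every frame emitted by _make_event (and persisted via chat_stream_message) is a
standard Server-Sent Events record of the form "event: <type>\ndata: <json>\n\n". A minimal
consumer sketch, assuming a server on localhost:8000 (the endpoint path and payload fields follow
the handlers in this diff; this client is not shipped code):

    import json

    import httpx

    async def consume_chat_stream(question: str) -> None:
        payload = {
            "thread_id": "__default__",
            "messages": [{"role": "user", "content": question}],
        }
        event_type = "message"
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST", "http://localhost:8000/api/chat/stream", json=payload
            ) as response:
                async for line in response.aiter_lines():
                    if line.startswith("event:"):
                        event_type = line.split(":", 1)[1].strip()
                    elif line.startswith("data:"):
                        data = json.loads(line.split(":", 1)[1])
                        print(event_type, data.get("agent"), data.get("content", ""))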

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Author: CHANGXUBO
Date: 2025-08-16 21:03:12 +08:00 (committed by GitHub)
parent d65b8f8fcc
commit 1bfec3ad05
12 changed files with 1558 additions and 119 deletions

src/server/app.py

@@ -4,7 +4,6 @@
 import base64
 import json
 import logging
-import os
 from typing import Annotated, List, cast
 from uuid import uuid4
@@ -13,8 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response, StreamingResponse
 from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
 from langgraph.types import Command
+from langgraph.store.memory import InMemoryStore
+from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
+from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+from psycopg.rows import dict_row
+from psycopg_pool import AsyncConnectionPool
-from src.config.configuration import get_recursion_limit
+from src.config.configuration import get_recursion_limit, get_bool_env, get_str_env
 from src.config.report_style import ReportStyle
 from src.config.tools import SELECTED_RAG_PROVIDER
 from src.graph.builder import build_graph_with_memory
@@ -42,6 +45,8 @@ from src.server.rag_request import (
     RAGResourcesResponse,
 )
 from src.tools import VolcengineTTS
+from src.graph.checkpoint import chat_stream_message
+from src.utils.json_utils import sanitize_args

 logger = logging.getLogger(__name__)
@@ -56,7 +61,7 @@ app = FastAPI(
 # Add CORS middleware
 # It's recommended to load the allowed origins from an environment variable
 # for better security and flexibility across different environments.
-allowed_origins_str = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000")
+allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
 allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]
 logger.info(f"Allowed origins: {allowed_origins}")
@@ -68,18 +73,14 @@ app.add_middleware(
     allow_methods=["GET", "POST", "OPTIONS"],  # Use the configured list of methods
     allow_headers=["*"],  # Now allow all headers, but can be restricted further
 )

+in_memory_store = InMemoryStore()
 graph = build_graph_with_memory()


 @app.post("/api/chat/stream")
 async def chat_stream(request: ChatRequest):
     # Check if MCP server configuration is enabled
-    mcp_enabled = os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() in [
-        "true",
-        "1",
-        "yes",
-    ]
+    mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)

     # Validate MCP settings if provided
     if request.mcp_settings and not mcp_enabled:
@@ -91,6 +92,7 @@ async def chat_stream(request: ChatRequest):
     thread_id = request.thread_id
     if thread_id == "__default__":
         thread_id = str(uuid4())
+
     return StreamingResponse(
         _astream_workflow_generator(
             request.model_dump()["messages"],
@@ -110,6 +112,154 @@ async def chat_stream(request: ChatRequest):
     )


+def _process_tool_call_chunks(tool_call_chunks):
+    """Process tool call chunks and sanitize arguments."""
+    chunks = []
+    for chunk in tool_call_chunks:
+        chunks.append(
+            {
+                "name": chunk.get("name", ""),
+                "args": sanitize_args(chunk.get("args", "")),
+                "id": chunk.get("id", ""),
+                "index": chunk.get("index", 0),
+                "type": chunk.get("type", ""),
+            }
+        )
+    return chunks
+
+
+def _get_agent_name(agent, message_metadata):
+    """Extract agent name from agent tuple, falling back to graph metadata."""
+    if agent and len(agent) > 0:
+        agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
+    else:
+        agent_name = message_metadata.get("langgraph_node", "unknown")
+    return agent_name
+
+
+def _create_event_stream_message(
+    message_chunk, message_metadata, thread_id, agent_name
+):
+    """Create base event stream message."""
+    event_stream_message = {
+        "thread_id": thread_id,
+        "agent": agent_name,
+        "id": message_chunk.id,
+        "role": "assistant",
+        "checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
+        "langgraph_node": message_metadata.get("langgraph_node", ""),
+        "langgraph_path": message_metadata.get("langgraph_path", ""),
+        "langgraph_step": message_metadata.get("langgraph_step", ""),
+        "content": message_chunk.content,
+    }
+
+    # Add optional fields
+    if message_chunk.additional_kwargs.get("reasoning_content"):
+        event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
+            "reasoning_content"
+        ]
+    if message_chunk.response_metadata.get("finish_reason"):
+        event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
+            "finish_reason"
+        )
+    return event_stream_message
+
+
+def _create_interrupt_event(thread_id, event_data):
+    """Create interrupt event."""
+    return _make_event(
+        "interrupt",
+        {
+            "thread_id": thread_id,
+            "id": event_data["__interrupt__"][0].ns[0],
+            "role": "assistant",
+            "content": event_data["__interrupt__"][0].value,
+            "finish_reason": "interrupt",
+            "options": [
+                {"text": "Edit plan", "value": "edit_plan"},
+                {"text": "Start research", "value": "accepted"},
+            ],
+        },
+    )
+
+
+def _process_initial_messages(message, thread_id):
+    """Persist an initial user message as a formatted stream event."""
+    json_data = json.dumps(
+        {
+            "thread_id": thread_id,
+            "id": "run--" + message.get("id", uuid4().hex),
+            "role": "user",
+            "content": message.get("content", ""),
+        },
+        ensure_ascii=False,
+        separators=(",", ":"),
+    )
+    chat_stream_message(
+        thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none"
+    )
+
+
+async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
+    """Process a single message chunk and yield appropriate events."""
+    agent_name = _get_agent_name(agent, message_metadata)
+    event_stream_message = _create_event_stream_message(
+        message_chunk, message_metadata, thread_id, agent_name
+    )
+
+    if isinstance(message_chunk, ToolMessage):
+        # Tool Message - Return the result of the tool call
+        event_stream_message["tool_call_id"] = message_chunk.tool_call_id
+        yield _make_event("tool_call_result", event_stream_message)
+    elif isinstance(message_chunk, AIMessageChunk):
+        if message_chunk.tool_calls:
+            # AI Message - Tool Call
+            event_stream_message["tool_calls"] = message_chunk.tool_calls
+            event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
+                message_chunk.tool_call_chunks
+            )
+            yield _make_event("tool_calls", event_stream_message)
+        elif message_chunk.tool_call_chunks:
+            # AI Message - Tool Call Chunks
+            event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
+                message_chunk.tool_call_chunks
+            )
+            yield _make_event("tool_call_chunks", event_stream_message)
+        else:
+            # AI Message - Raw message tokens
+            yield _make_event("message_chunk", event_stream_message)
+
+
+async def _stream_graph_events(
+    graph_instance, workflow_input, workflow_config, thread_id
+):
+    """Stream events from the graph and process them."""
+    async for agent, _, event_data in graph_instance.astream(
+        workflow_input,
+        config=workflow_config,
+        stream_mode=["messages", "updates"],
+        subgraphs=True,
+    ):
+        if isinstance(event_data, dict):
+            if "__interrupt__" in event_data:
+                yield _create_interrupt_event(thread_id, event_data)
+            continue
+
+        message_chunk, message_metadata = cast(
+            tuple[BaseMessage, dict[str, any]], event_data
+        )
+        async for event in _process_message_chunk(
+            message_chunk, message_metadata, thread_id, agent
+        ):
+            yield event
+
+
 async def _astream_workflow_generator(
     messages: List[dict],
     thread_id: str,
@@ -124,7 +274,13 @@ async def _astream_workflow_generator(
     report_style: ReportStyle,
     enable_deep_thinking: bool,
 ):
-    input_ = {
+    # Process initial messages
+    for message in messages:
+        if isinstance(message, dict) and "content" in message:
+            _process_initial_messages(message, thread_id)
+
+    # Prepare workflow input
+    workflow_input = {
         "messages": messages,
         "plan_iterations": 0,
         "final_report": "",
@@ -134,112 +290,105 @@ async def _astream_workflow_generator(
         "enable_background_investigation": enable_background_investigation,
         "research_topic": messages[-1]["content"] if messages else "",
     }
     if not auto_accepted_plan and interrupt_feedback:
         resume_msg = f"[{interrupt_feedback}]"
         # add the last message to the resume message
         if messages:
             resume_msg += f" {messages[-1]['content']}"
-        input_ = Command(resume=resume_msg)
-    async for agent, _, event_data in graph.astream(
-        input_,
-        config={
-            "thread_id": thread_id,
-            "resources": resources,
-            "max_plan_iterations": max_plan_iterations,
-            "max_step_num": max_step_num,
-            "max_search_results": max_search_results,
-            "mcp_settings": mcp_settings,
-            "report_style": report_style.value,
-            "enable_deep_thinking": enable_deep_thinking,
-            "recursion_limit": get_recursion_limit(),
-        },
-        stream_mode=["messages", "updates"],
-        subgraphs=True,
-    ):
-        if isinstance(event_data, dict):
-            if "__interrupt__" in event_data:
-                yield _make_event(
-                    "interrupt",
-                    {
-                        "thread_id": thread_id,
-                        "id": event_data["__interrupt__"][0].ns[0],
-                        "role": "assistant",
-                        "content": event_data["__interrupt__"][0].value,
-                        "finish_reason": "interrupt",
-                        "options": [
-                            {"text": "Edit plan", "value": "edit_plan"},
-                            {"text": "Start research", "value": "accepted"},
-                        ],
-                    },
-                )
-            continue
-        message_chunk, message_metadata = cast(
-            tuple[BaseMessage, dict[str, any]], event_data
-        )
-        # Handle empty agent tuple gracefully
-        agent_name = "planner"
-        if agent and len(agent) > 0:
-            agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
-        event_stream_message: dict[str, any] = {
-            "thread_id": thread_id,
-            "agent": agent_name,
-            "id": message_chunk.id,
-            "role": "assistant",
-            "content": message_chunk.content,
-        }
-        if message_chunk.additional_kwargs.get("reasoning_content"):
-            event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
-                "reasoning_content"
-            ]
-        if message_chunk.response_metadata.get("finish_reason"):
-            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
-                "finish_reason"
-            )
-        if isinstance(message_chunk, ToolMessage):
-            # Tool Message - Return the result of the tool call
-            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
-            yield _make_event("tool_call_result", event_stream_message)
-        elif isinstance(message_chunk, AIMessageChunk):
-            # AI Message - Raw message tokens
-            if message_chunk.tool_calls:
-                # AI Message - Tool Call
-                event_stream_message["tool_calls"] = message_chunk.tool_calls
-                event_stream_message["tool_call_chunks"] = (
-                    message_chunk.tool_call_chunks
-                )
-                yield _make_event("tool_calls", event_stream_message)
-            elif message_chunk.tool_call_chunks:
-                # AI Message - Tool Call Chunks
-                event_stream_message["tool_call_chunks"] = (
-                    message_chunk.tool_call_chunks
-                )
-                yield _make_event("tool_call_chunks", event_stream_message)
-            else:
-                # AI Message - Raw message tokens
-                yield _make_event("message_chunk", event_stream_message)
+        workflow_input = Command(resume=resume_msg)
+
+    # Prepare workflow config
+    workflow_config = {
+        "thread_id": thread_id,
+        "resources": resources,
+        "max_plan_iterations": max_plan_iterations,
+        "max_step_num": max_step_num,
+        "max_search_results": max_search_results,
+        "mcp_settings": mcp_settings,
+        "report_style": report_style.value,
+        "enable_deep_thinking": enable_deep_thinking,
+        "recursion_limit": get_recursion_limit(),
+    }
+
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
+
+    # Handle checkpointer if configured
+    connection_kwargs = {
+        "autocommit": True,
+        "row_factory": dict_row,
+        "prepare_threshold": 0,
+    }
+    if checkpoint_saver and checkpoint_url != "":
+        if checkpoint_url.startswith("postgresql://"):
+            logger.info("start async postgres checkpointer.")
+            async with AsyncConnectionPool(
+                checkpoint_url, kwargs=connection_kwargs
+            ) as conn:
+                checkpointer = AsyncPostgresSaver(conn)
+                await checkpointer.setup()
+                graph.checkpointer = checkpointer
+                graph.store = in_memory_store
+                async for event in _stream_graph_events(
+                    graph, workflow_input, workflow_config, thread_id
+                ):
+                    yield event
+        if checkpoint_url.startswith("mongodb://"):
+            logger.info("start async mongodb checkpointer.")
+            async with AsyncMongoDBSaver.from_conn_string(
+                checkpoint_url
+            ) as checkpointer:
+                graph.checkpointer = checkpointer
+                graph.store = in_memory_store
+                async for event in _stream_graph_events(
+                    graph, workflow_input, workflow_config, thread_id
+                ):
+                    yield event
+    else:
+        # Use the graph without a checkpointer
+        async for event in _stream_graph_events(
+            graph, workflow_input, workflow_config, thread_id
+        ):
+            yield event


 def _make_event(event_type: str, data: dict[str, any]):
     if data.get("content") == "":
         data.pop("content")
-    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+    # Ensure JSON serialization with proper encoding
+    try:
+        json_data = json.dumps(data, ensure_ascii=False)
+        finish_reason = data.get("finish_reason", "")
+        chat_stream_message(
+            data.get("thread_id", ""),
+            f"event: {event_type}\ndata: {json_data}\n\n",
+            finish_reason,
+        )
+        return f"event: {event_type}\ndata: {json_data}\n\n"
+    except (TypeError, ValueError) as e:
+        logger.error(f"Error serializing event data: {e}")
+        # Return a safe error event
+        error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
+        return f"event: error\ndata: {error_data}\n\n"
 @app.post("/api/tts")
 async def text_to_speech(request: TTSRequest):
     """Convert text to speech using volcengine TTS API."""
-    app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
+    app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
     if not app_id:
         raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
-    access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
+    access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
     if not access_token:
         raise HTTPException(
             status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
         )
     try:
-        cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
-        voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
+        cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
+        voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")

         tts_client = VolcengineTTS(
             appid=app_id,
@@ -382,11 +531,7 @@ async def enhance_prompt(request: EnhancePromptRequest):
 async def mcp_server_metadata(request: MCPServerMetadataRequest):
     """Get information about an MCP server."""
     # Check if MCP server configuration is enabled
-    if os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() not in [
-        "true",
-        "1",
-        "yes",
-    ]:
+    if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
         raise HTTPException(
             status_code=403,
             detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",