diff --git a/.env.example b/.env.example
index 9c1e319..2782a06 100644
--- a/.env.example
+++ b/.env.example
@@ -55,3 +55,10 @@ VOLCENGINE_TTS_ACCESS_TOKEN=xxx
 # [!NOTE]
 # For model settings and other configurations, please refer to `docs/configuration_guide.md`
+
+# Optional, for the LangGraph MongoDB/Postgres checkpointer
+# Enable LangGraph checkpoint saver, supports MongoDB, Postgres
+#LANGGRAPH_CHECKPOINT_SAVER=true
+# Set the database URL for saving checkpoints
+#LANGGRAPH_CHECKPOINT_DB_URL="mongodb://localhost:27017/"
+#LANGGRAPH_CHECKPOINT_DB_URL=postgresql://localhost:5432/postgres
\ No newline at end of file
diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index c663735..f1fde21 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -12,6 +12,31 @@ permissions:
 jobs:
   test:
     runs-on: ubuntu-latest
+    services:
+      postgres:
+        image: postgres:15
+        env:
+          POSTGRES_DB: checkpointing_db
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: postgres
+        ports: ["5432:5432"]
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+      mongodb:
+        image: mongo:6
+        env:
+          MONGO_INITDB_ROOT_USERNAME: admin
+          MONGO_INITDB_ROOT_PASSWORD: admin
+          MONGO_INITDB_DATABASE: checkpointing_db
+        ports: ["27017:27017"]
+        options: >-
+          --health-cmd "mongosh --eval 'db.runCommand(\"ping\").ok'"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 3
     steps:
       - uses: actions/checkout@v3
diff --git a/README.md b/README.md
index 629b1b3..ad7de29 100644
--- a/README.md
+++ b/README.md
@@ -386,6 +386,33 @@ DeerFlow supports LangSmith tracing to help you debug and monitor your workflows
 This will enable trace visualization in LangGraph Studio and send your traces to LangSmith for monitoring and analysis.
+
+### Checkpointing
+1. Postgres and MongoDB implementation of LangGraph checkpoint saver.
+2. 
In-memory store is used to cache the streaming messages before persisting them to the database. If finish_reason is "stop" or "interrupt", it triggers persistence.
+3. Supports saving and loading checkpoints for workflow execution.
+4. Supports saving chat stream events for replaying conversations.
+
+Note:
+The latest langgraph-checkpoint-postgres-2.0.23 has a checkpointing issue; see the open issue "TypeError: Object of type HumanMessage is not JSON serializable" [https://github.com/langchain-ai/langgraph/issues/5557].
+
+To use the Postgres checkpointer, install langgraph-checkpoint-postgres-2.0.21
+
+The default database and collections will be automatically created if they do not exist.
+Default database: checkpointing_db
+Default collection: checkpoint_writes_aio (langgraph checkpoint writes)
+Default collection: checkpoints_aio (langgraph checkpoints)
+Default collection: chat_streams (chat stream events for replaying conversations)
+
+You need to set the following environment variables in your `.env` file:
+
+```bash
+# Enable LangGraph checkpoint saver, supports MongoDB, Postgres
+LANGGRAPH_CHECKPOINT_SAVER=true
+# Set the database URL for saving checkpoints
+LANGGRAPH_CHECKPOINT_DB_URL="mongodb://localhost:27017/"
+#LANGGRAPH_CHECKPOINT_DB_URL=postgresql://localhost:5432/postgres
+```
+
 ## Docker
 
 You can also run this project with Docker.
diff --git a/pyproject.toml b/pyproject.toml
index 52e67f3..f0b21a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,8 @@ dependencies = [
     "langchain-mcp-adapters>=0.0.9",
     "langchain-deepseek>=0.1.3",
     "wikipedia>=1.4.0",
+    "langgraph-checkpoint-mongodb>=0.1.4",
+    "langgraph-checkpoint-postgres==2.0.21",
 ]
 
 [project.optional-dependencies]
diff --git a/server.py b/server.py
index 04461aa..727c840 100644
--- a/server.py
+++ b/server.py
@@ -4,7 +4,8 @@
 """
 Server script for running the DeerFlow API.
""" - +import os +import asyncio import argparse import logging import signal @@ -19,6 +20,17 @@ logging.basicConfig( logger = logging.getLogger(__name__) +# To ensure compatibility with Windows event loop issues when using Uvicorn and Asyncio Checkpointer, +# This is necessary because some libraries expect a selector-based event loop. +# This is a workaround for issues with Uvicorn and Watchdog on Windows. +# See: +# Since Python 3.8 the default on Windows is the Proactor event loop, +# which lacks add_reader/add_writer and can break libraries that expect selector-based I/O (e.g., some Uvicorn/Watchdog/stdio integrations). +# For compatibility, this forces the selector loop. +if os.name == "nt": + logger.info("Setting Windows event loop policy for asyncio") + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + def handle_shutdown(signum, frame): """Handle graceful shutdown on SIGTERM/SIGINT""" diff --git a/src/config/configuration.py b/src/config/configuration.py index 4ffc1bf..ba39af3 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -13,6 +13,33 @@ from src.config.report_style import ReportStyle logger = logging.getLogger(__name__) +_TRUTHY = {"1", "true", "yes", "y", "on"} + + +def get_bool_env(name: str, default: bool = False) -> bool: + val = os.getenv(name) + if val is None: + return default + return str(val).strip().lower() in _TRUTHY + + +def get_str_env(name: str, default: str = "") -> str: + val = os.getenv(name) + return default if val is None else str(val).strip() + + +def get_int_env(name: str, default: int = 0) -> int: + val = os.getenv(name) + if val is None: + return default + try: + return int(val.strip()) + except ValueError: + logger.warning( + f"Invalid integer value for {name}: {val}. Using default {default}." + ) + return default + def get_recursion_limit(default: int = 25) -> int: """Get the recursion limit from environment variable or use default. 
@@ -23,23 +50,15 @@ def get_recursion_limit(default: int = 25) -> int: Returns: int: The recursion limit to use """ - try: - env_value_str = os.getenv("AGENT_RECURSION_LIMIT", str(default)) - parsed_limit = int(env_value_str) + env_value_str = get_str_env("AGENT_RECURSION_LIMIT", str(default)) + parsed_limit = get_int_env("AGENT_RECURSION_LIMIT", default) - if parsed_limit > 0: - logger.info(f"Recursion limit set to: {parsed_limit}") - return parsed_limit - else: - logger.warning( - f"AGENT_RECURSION_LIMIT value '{env_value_str}' (parsed as {parsed_limit}) is not positive. " - f"Using default value {default}." - ) - return default - except ValueError: - raw_env_value = os.getenv("AGENT_RECURSION_LIMIT") + if parsed_limit > 0: + logger.info(f"Recursion limit set to: {parsed_limit}") + return parsed_limit + else: logger.warning( - f"Invalid AGENT_RECURSION_LIMIT value: '{raw_env_value}'. " + f"AGENT_RECURSION_LIMIT value '{env_value_str}' (parsed as {parsed_limit}) is not positive. " f"Using default value {default}." ) return default diff --git a/src/graph/checkpoint.py b/src/graph/checkpoint.py new file mode 100644 index 0000000..a0d9597 --- /dev/null +++ b/src/graph/checkpoint.py @@ -0,0 +1,372 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +import logging +import uuid +from datetime import datetime +from typing import List, Optional, Tuple +import psycopg +from psycopg.rows import dict_row +from pymongo import MongoClient +from langgraph.store.memory import InMemoryStore +from src.config.configuration import get_bool_env, get_str_env + + +class ChatStreamManager: + """ + Manages chat stream messages with persistent storage and in-memory caching. + + This class handles the storage and retrieval of chat messages using both + an in-memory store for temporary data and MongoDB or PostgreSQL for persistent storage. + It tracks message chunks and consolidates them when a conversation finishes. 
+ + Attributes: + store (InMemoryStore): In-memory storage for temporary message chunks + mongo_client (MongoClient): MongoDB client connection + mongo_db (Database): MongoDB database instance + postgres_conn (psycopg.Connection): PostgreSQL connection + logger (logging.Logger): Logger instance for this class + """ + + def __init__( + self, checkpoint_saver: bool = False, db_uri: Optional[str] = None + ) -> None: + """ + Initialize the ChatStreamManager with database connections. + + Args: + db_uri: Database connection URI. Supports MongoDB (mongodb://) and PostgreSQL (postgresql://) + If None, uses LANGGRAPH_CHECKPOINT_DB_URL env var or defaults to localhost + """ + self.logger = logging.getLogger(__name__) + self.store = InMemoryStore() + self.checkpoint_saver = checkpoint_saver + # Use provided URI or fall back to environment variable or default + self.db_uri = db_uri + + # Initialize database connections + self.mongo_client = None + self.mongo_db = None + self.postgres_conn = None + + if self.checkpoint_saver: + if self.db_uri.startswith("mongodb://"): + self._init_mongodb() + elif self.db_uri.startswith("postgresql://") or self.db_uri.startswith( + "postgres://" + ): + self._init_postgresql() + else: + self.logger.warning( + f"Unsupported database URI scheme: {self.db_uri}. 
" + "Supported schemes: mongodb://, postgresql://, postgres://" + ) + else: + self.logger.warning("Checkpoint saver is disabled") + + def _init_mongodb(self) -> None: + """Initialize MongoDB connection.""" + + try: + self.mongo_client = MongoClient(self.db_uri) + self.mongo_db = self.mongo_client.checkpointing_db + # Test connection + self.mongo_client.admin.command("ping") + self.logger.info("Successfully connected to MongoDB") + except Exception as e: + self.logger.error(f"Failed to connect to MongoDB: {e}") + + def _init_postgresql(self) -> None: + """Initialize PostgreSQL connection and create table if needed.""" + + try: + self.postgres_conn = psycopg.connect(self.db_uri, row_factory=dict_row) + self.logger.info("Successfully connected to PostgreSQL") + self._create_chat_streams_table() + except Exception as e: + self.logger.error(f"Failed to connect to PostgreSQL: {e}") + + def _create_chat_streams_table(self) -> None: + """Create the chat_streams table if it doesn't exist.""" + try: + with self.postgres_conn.cursor() as cursor: + create_table_sql = """ + CREATE TABLE IF NOT EXISTS chat_streams ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + thread_id VARCHAR(255) NOT NULL UNIQUE, + messages JSONB NOT NULL, + ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_chat_streams_thread_id ON chat_streams(thread_id); + CREATE INDEX IF NOT EXISTS idx_chat_streams_ts ON chat_streams(ts); + """ + cursor.execute(create_table_sql) + self.postgres_conn.commit() + self.logger.info("Chat streams table created/verified successfully") + except Exception as e: + self.logger.error(f"Failed to create chat_streams table: {e}") + if self.postgres_conn: + self.postgres_conn.rollback() + + def process_stream_message( + self, thread_id: str, message: str, finish_reason: str + ) -> bool: + """ + Process and store a chat stream message chunk. 
+ + This method handles individual message chunks during streaming and consolidates + them into a complete message when the stream finishes. Messages are stored + temporarily in memory and permanently in MongoDB when complete. + + Args: + thread_id: Unique identifier for the conversation thread + message: The message content or chunk to store + finish_reason: Reason for message completion ("stop", "interrupt", or partial) + + Returns: + bool: True if message was processed successfully, False otherwise + """ + if not thread_id or not isinstance(thread_id, str): + self.logger.warning("Invalid thread_id provided") + return False + + if not message: + self.logger.warning("Empty message provided") + return False + + try: + # Create namespace for this thread's messages + store_namespace: Tuple[str, str] = ("messages", thread_id) + + # Get or initialize message cursor for tracking chunks + cursor = self.store.get(store_namespace, "cursor") + current_index = 0 + + if cursor is None: + # Initialize cursor for new conversation + self.store.put(store_namespace, "cursor", {"index": 0}) + else: + # Increment index for next chunk + current_index = int(cursor.value.get("index", 0)) + 1 + self.store.put(store_namespace, "cursor", {"index": current_index}) + + # Store the current message chunk + self.store.put(store_namespace, f"chunk_{current_index}", message) + + # Check if conversation is complete and should be persisted + if finish_reason in ("stop", "interrupt"): + return self._persist_complete_conversation( + thread_id, store_namespace, current_index + ) + + return True + + except Exception as e: + self.logger.error( + f"Error processing stream message for thread {thread_id}: {e}" + ) + return False + + def _persist_complete_conversation( + self, thread_id: str, store_namespace: Tuple[str, str], final_index: int + ) -> bool: + """ + Persist completed conversation to database (MongoDB or PostgreSQL). 
+ + Retrieves all message chunks from memory store and saves the complete + conversation to the configured database for permanent storage. + + Args: + thread_id: Unique identifier for the conversation thread + store_namespace: Namespace tuple for accessing stored messages + final_index: The final chunk index for this conversation + + Returns: + bool: True if persistence was successful, False otherwise + """ + try: + # Retrieve all message chunks from memory store + # Get all messages up to the final index including cursor metadata + memories = self.store.search(store_namespace, limit=final_index + 2) + + # Extract message content, filtering out cursor metadata + messages: List[str] = [] + for item in memories: + value = item.dict().get("value", "") + # Skip cursor metadata, only include actual message chunks + if value and not isinstance(value, dict): + messages.append(str(value)) + + if not messages: + self.logger.warning(f"No messages found for thread {thread_id}") + return False + + if not self.checkpoint_saver: + self.logger.warning("Checkpoint saver is disabled") + return False + + # Choose persistence method based on available connection + if self.mongo_db is not None: + return self._persist_to_mongodb(thread_id, messages) + elif self.postgres_conn is not None: + return self._persist_to_postgresql(thread_id, messages) + else: + self.logger.warning("No database connection available") + return False + + except Exception as e: + self.logger.error( + f"Error persisting conversation for thread {thread_id}: {e}" + ) + return False + + def _persist_to_mongodb(self, thread_id: str, messages: List[str]) -> bool: + """Persist conversation to MongoDB.""" + try: + # Get MongoDB collection for chat streams + collection = self.mongo_db.chat_streams + + # Check if conversation already exists in database + existing_document = collection.find_one({"thread_id": thread_id}) + + current_timestamp = datetime.now() + + if existing_document: + # Update existing conversation with 
new messages + update_result = collection.update_one( + {"thread_id": thread_id}, + {"$set": {"messages": messages, "ts": current_timestamp}}, + ) + self.logger.info( + f"Updated conversation for thread {thread_id}: " + f"{update_result.modified_count} documents modified" + ) + return update_result.modified_count > 0 + else: + # Create new conversation document + new_document = { + "thread_id": thread_id, + "messages": messages, + "ts": current_timestamp, + "id": uuid.uuid4().hex, + } + insert_result = collection.insert_one(new_document) + self.logger.info( + f"Created new conversation: {insert_result.inserted_id}" + ) + return insert_result.inserted_id is not None + + except Exception as e: + self.logger.error(f"Error persisting to MongoDB: {e}") + return False + + def _persist_to_postgresql(self, thread_id: str, messages: List[str]) -> bool: + """Persist conversation to PostgreSQL.""" + try: + with self.postgres_conn.cursor() as cursor: + # Check if conversation already exists + cursor.execute( + "SELECT id FROM chat_streams WHERE thread_id = %s", (thread_id,) + ) + existing_record = cursor.fetchone() + + current_timestamp = datetime.now() + messages_json = json.dumps(messages) + + if existing_record: + # Update existing conversation with new messages + cursor.execute( + """ + UPDATE chat_streams + SET messages = %s, ts = %s + WHERE thread_id = %s + """, + (messages_json, current_timestamp, thread_id), + ) + affected_rows = cursor.rowcount + self.postgres_conn.commit() + + self.logger.info( + f"Updated conversation for thread {thread_id}: " + f"{affected_rows} rows modified" + ) + return affected_rows > 0 + else: + # Create new conversation record + conversation_id = uuid.uuid4() + cursor.execute( + """ + INSERT INTO chat_streams (id, thread_id, messages, ts) + VALUES (%s, %s, %s, %s) + """, + (conversation_id, thread_id, messages_json, current_timestamp), + ) + affected_rows = cursor.rowcount + self.postgres_conn.commit() + + self.logger.info( + f"Created new 
conversation with ID: {conversation_id}" + ) + return affected_rows > 0 + + except Exception as e: + self.logger.error(f"Error persisting to PostgreSQL: {e}") + if self.postgres_conn: + self.postgres_conn.rollback() + return False + + def close(self) -> None: + """Close database connections.""" + try: + if self.mongo_client is not None: + self.mongo_client.close() + self.logger.info("MongoDB connection closed") + except Exception as e: + self.logger.error(f"Error closing MongoDB connection: {e}") + + try: + if self.postgres_conn is not None: + self.postgres_conn.close() + self.logger.info("PostgreSQL connection closed") + except Exception as e: + self.logger.error(f"Error closing PostgreSQL connection: {e}") + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - close connections.""" + self.close() + + +# Global instance for backward compatibility +# TODO: Consider using dependency injection instead of global instance +_default_manager = ChatStreamManager( + checkpoint_saver=get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False), + db_uri=get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "mongodb://localhost:27017"), +) + + +def chat_stream_message(thread_id: str, message: str, finish_reason: str) -> bool: + """ + Legacy function wrapper for backward compatibility. 
+ + Args: + thread_id: Unique identifier for the conversation thread + message: The message content to store + finish_reason: Reason for message completion + + Returns: + bool: True if message was processed successfully + """ + checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False) + if checkpoint_saver: + return _default_manager.process_stream_message( + thread_id, message, finish_reason + ) + else: + logging.warning("Checkpoint saver is disabled, message not processed") + return False diff --git a/src/server/app.py b/src/server/app.py index 860cbef..efea5b8 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -4,7 +4,6 @@ import base64 import json import logging -import os from typing import Annotated, List, cast from uuid import uuid4 @@ -13,8 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage from langgraph.types import Command +from langgraph.store.memory import InMemoryStore +from langgraph.checkpoint.mongodb import AsyncMongoDBSaver +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from psycopg_pool import AsyncConnectionPool -from src.config.configuration import get_recursion_limit +from src.config.configuration import get_recursion_limit, get_bool_env, get_str_env from src.config.report_style import ReportStyle from src.config.tools import SELECTED_RAG_PROVIDER from src.graph.builder import build_graph_with_memory @@ -42,6 +45,8 @@ from src.server.rag_request import ( RAGResourcesResponse, ) from src.tools import VolcengineTTS +from src.graph.checkpoint import chat_stream_message +from src.utils.json_utils import sanitize_args logger = logging.getLogger(__name__) @@ -56,7 +61,7 @@ app = FastAPI( # Add CORS middleware # It's recommended to load the allowed origins from an environment variable # for better security and flexibility across different environments. 
-allowed_origins_str = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000") +allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000") allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")] logger.info(f"Allowed origins: {allowed_origins}") @@ -68,18 +73,14 @@ app.add_middleware( allow_methods=["GET", "POST", "OPTIONS"], # Use the configured list of methods allow_headers=["*"], # Now allow all headers, but can be restricted further ) - +in_memory_store = InMemoryStore() graph = build_graph_with_memory() @app.post("/api/chat/stream") async def chat_stream(request: ChatRequest): # Check if MCP server configuration is enabled - mcp_enabled = os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() in [ - "true", - "1", - "yes", - ] + mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False) # Validate MCP settings if provided if request.mcp_settings and not mcp_enabled: @@ -91,6 +92,7 @@ async def chat_stream(request: ChatRequest): thread_id = request.thread_id if thread_id == "__default__": thread_id = str(uuid4()) + return StreamingResponse( _astream_workflow_generator( request.model_dump()["messages"], @@ -110,6 +112,154 @@ async def chat_stream(request: ChatRequest): ) +def _process_tool_call_chunks(tool_call_chunks): + """Process tool call chunks and sanitize arguments.""" + chunks = [] + for chunk in tool_call_chunks: + chunks.append( + { + "name": chunk.get("name", ""), + "args": sanitize_args(chunk.get("args", "")), + "id": chunk.get("id", ""), + "index": chunk.get("index", 0), + "type": chunk.get("type", ""), + } + ) + return chunks + + +def _get_agent_name(agent, message_metadata): + """Extract agent name from agent tuple.""" + agent_name = "unknown" + if agent and len(agent) > 0: + agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0] + else: + agent_name = message_metadata.get("langgraph_node", "unknown") + return agent_name + + +def _create_event_stream_message( + message_chunk, 
message_metadata, thread_id, agent_name +): + """Create base event stream message.""" + event_stream_message = { + "thread_id": thread_id, + "agent": agent_name, + "id": message_chunk.id, + "role": "assistant", + "checkpoint_ns": message_metadata.get("checkpoint_ns", ""), + "langgraph_node": message_metadata.get("langgraph_node", ""), + "langgraph_path": message_metadata.get("langgraph_path", ""), + "langgraph_step": message_metadata.get("langgraph_step", ""), + "content": message_chunk.content, + } + + # Add optional fields + if message_chunk.additional_kwargs.get("reasoning_content"): + event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[ + "reasoning_content" + ] + + if message_chunk.response_metadata.get("finish_reason"): + event_stream_message["finish_reason"] = message_chunk.response_metadata.get( + "finish_reason" + ) + + return event_stream_message + + +def _create_interrupt_event(thread_id, event_data): + """Create interrupt event.""" + return _make_event( + "interrupt", + { + "thread_id": thread_id, + "id": event_data["__interrupt__"][0].ns[0], + "role": "assistant", + "content": event_data["__interrupt__"][0].value, + "finish_reason": "interrupt", + "options": [ + {"text": "Edit plan", "value": "edit_plan"}, + {"text": "Start research", "value": "accepted"}, + ], + }, + ) + + +def _process_initial_messages(message, thread_id): + """Process initial messages and yield formatted events.""" + json_data = json.dumps( + { + "thread_id": thread_id, + "id": "run--" + message.get("id", uuid4().hex), + "role": "user", + "content": message.get("content", ""), + }, + ensure_ascii=False, + separators=(",", ":"), + ) + chat_stream_message( + thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none" + ) + + +async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent): + """Process a single message chunk and yield appropriate events.""" + agent_name = _get_agent_name(agent, message_metadata) + event_stream_message 
= _create_event_stream_message( + message_chunk, message_metadata, thread_id, agent_name + ) + + if isinstance(message_chunk, ToolMessage): + # Tool Message - Return the result of the tool call + event_stream_message["tool_call_id"] = message_chunk.tool_call_id + yield _make_event("tool_call_result", event_stream_message) + elif isinstance(message_chunk, AIMessageChunk): + # AI Message - Raw message tokens + if message_chunk.tool_calls: + # AI Message - Tool Call + event_stream_message["tool_calls"] = message_chunk.tool_calls + event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks + event_stream_message["tool_call_chunks"] = _process_tool_call_chunks( + message_chunk.tool_call_chunks + ) + yield _make_event("tool_calls", event_stream_message) + elif message_chunk.tool_call_chunks: + # AI Message - Tool Call Chunks + event_stream_message["tool_call_chunks"] = _process_tool_call_chunks( + message_chunk.tool_call_chunks + ) + yield _make_event("tool_call_chunks", event_stream_message) + else: + # AI Message - Raw message tokens + yield _make_event("message_chunk", event_stream_message) + + +async def _stream_graph_events( + graph_instance, workflow_input, workflow_config, thread_id +): + """Stream events from the graph and process them.""" + async for agent, _, event_data in graph_instance.astream( + workflow_input, + config=workflow_config, + stream_mode=["messages", "updates"], + subgraphs=True, + ): + if isinstance(event_data, dict): + if "__interrupt__" in event_data: + yield _create_interrupt_event(thread_id, event_data) + continue + + message_chunk, message_metadata = cast( + tuple[BaseMessage, dict[str, any]], event_data + ) + + async for event in _process_message_chunk( + message_chunk, message_metadata, thread_id, agent + ): + yield event + + async def _astream_workflow_generator( messages: List[dict], thread_id: str, @@ -124,7 +274,13 @@ async def _astream_workflow_generator( report_style: ReportStyle, enable_deep_thinking: bool, ): - 
input_ = { + # Process initial messages + for message in messages: + if isinstance(message, dict) and "content" in message: + _process_initial_messages(message, thread_id) + + # Prepare workflow input + workflow_input = { "messages": messages, "plan_iterations": 0, "final_report": "", @@ -134,112 +290,105 @@ async def _astream_workflow_generator( "enable_background_investigation": enable_background_investigation, "research_topic": messages[-1]["content"] if messages else "", } + if not auto_accepted_plan and interrupt_feedback: resume_msg = f"[{interrupt_feedback}]" - # add the last message to the resume message if messages: resume_msg += f" {messages[-1]['content']}" - input_ = Command(resume=resume_msg) - async for agent, _, event_data in graph.astream( - input_, - config={ - "thread_id": thread_id, - "resources": resources, - "max_plan_iterations": max_plan_iterations, - "max_step_num": max_step_num, - "max_search_results": max_search_results, - "mcp_settings": mcp_settings, - "report_style": report_style.value, - "enable_deep_thinking": enable_deep_thinking, - "recursion_limit": get_recursion_limit(), - }, - stream_mode=["messages", "updates"], - subgraphs=True, - ): - if isinstance(event_data, dict): - if "__interrupt__" in event_data: - yield _make_event( - "interrupt", - { - "thread_id": thread_id, - "id": event_data["__interrupt__"][0].ns[0], - "role": "assistant", - "content": event_data["__interrupt__"][0].value, - "finish_reason": "interrupt", - "options": [ - {"text": "Edit plan", "value": "edit_plan"}, - {"text": "Start research", "value": "accepted"}, - ], - }, - ) - continue - message_chunk, message_metadata = cast( - tuple[BaseMessage, dict[str, any]], event_data - ) - # Handle empty agent tuple gracefully - agent_name = "planner" - if agent and len(agent) > 0: - agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0] - event_stream_message: dict[str, any] = { - "thread_id": thread_id, - "agent": agent_name, - "id": message_chunk.id, - 
"role": "assistant", - "content": message_chunk.content, - } - if message_chunk.additional_kwargs.get("reasoning_content"): - event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[ - "reasoning_content" - ] - if message_chunk.response_metadata.get("finish_reason"): - event_stream_message["finish_reason"] = message_chunk.response_metadata.get( - "finish_reason" - ) - if isinstance(message_chunk, ToolMessage): - # Tool Message - Return the result of the tool call - event_stream_message["tool_call_id"] = message_chunk.tool_call_id - yield _make_event("tool_call_result", event_stream_message) - elif isinstance(message_chunk, AIMessageChunk): - # AI Message - Raw message tokens - if message_chunk.tool_calls: - # AI Message - Tool Call - event_stream_message["tool_calls"] = message_chunk.tool_calls - event_stream_message["tool_call_chunks"] = ( - message_chunk.tool_call_chunks - ) - yield _make_event("tool_calls", event_stream_message) - elif message_chunk.tool_call_chunks: - # AI Message - Tool Call Chunks - event_stream_message["tool_call_chunks"] = ( - message_chunk.tool_call_chunks - ) - yield _make_event("tool_call_chunks", event_stream_message) - else: - # AI Message - Raw message tokens - yield _make_event("message_chunk", event_stream_message) + workflow_input = Command(resume=resume_msg) + + # Prepare workflow config + workflow_config = { + "thread_id": thread_id, + "resources": resources, + "max_plan_iterations": max_plan_iterations, + "max_step_num": max_step_num, + "max_search_results": max_search_results, + "mcp_settings": mcp_settings, + "report_style": report_style.value, + "enable_deep_thinking": enable_deep_thinking, + "recursion_limit": get_recursion_limit(), + } + + checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False) + checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "") + # Handle checkpointer if configured + connection_kwargs = { + "autocommit": True, + "row_factory": "dict_row", + 
"prepare_threshold": 0, + } + if checkpoint_saver and checkpoint_url != "": + if checkpoint_url.startswith("postgresql://"): + logger.info("start async postgres checkpointer.") + async with AsyncConnectionPool( + checkpoint_url, kwargs=connection_kwargs + ) as conn: + checkpointer = AsyncPostgresSaver(conn) + await checkpointer.setup() + graph.checkpointer = checkpointer + graph.store = in_memory_store + async for event in _stream_graph_events( + graph, workflow_input, workflow_config, thread_id + ): + yield event + + if checkpoint_url.startswith("mongodb://"): + logger.info("start async mongodb checkpointer.") + async with AsyncMongoDBSaver.from_conn_string( + checkpoint_url + ) as checkpointer: + graph.checkpointer = checkpointer + graph.store = in_memory_store + async for event in _stream_graph_events( + graph, workflow_input, workflow_config, thread_id + ): + yield event + else: + # Use graph without MongoDB checkpointer + async for event in _stream_graph_events( + graph, workflow_input, workflow_config, thread_id + ): + yield event def _make_event(event_type: str, data: dict[str, any]): if data.get("content") == "": data.pop("content") - return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + # Ensure JSON serialization with proper encoding + try: + json_data = json.dumps(data, ensure_ascii=False) + + finish_reason = data.get("finish_reason", "") + chat_stream_message( + data.get("thread_id", ""), + f"event: {event_type}\ndata: {json_data}\n\n", + finish_reason, + ) + + return f"event: {event_type}\ndata: {json_data}\n\n" + except (TypeError, ValueError) as e: + logger.error(f"Error serializing event data: {e}") + # Return a safe error event + error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False) + return f"event: error\ndata: {error_data}\n\n" @app.post("/api/tts") async def text_to_speech(request: TTSRequest): """Convert text to speech using volcengine TTS API.""" - app_id = os.getenv("VOLCENGINE_TTS_APPID", 
"") + app_id = get_str_env("VOLCENGINE_TTS_APPID", "") if not app_id: raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set") - access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "") + access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "") if not access_token: raise HTTPException( status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set" ) try: - cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts") - voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming") + cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts") + voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming") tts_client = VolcengineTTS( appid=app_id, @@ -382,11 +531,7 @@ async def enhance_prompt(request: EnhancePromptRequest): async def mcp_server_metadata(request: MCPServerMetadataRequest): """Get information about an MCP server.""" # Check if MCP server configuration is enabled - if os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() not in [ - "true", - "1", - "yes", - ]: + if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False): raise HTTPException( status_code=403, detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.", diff --git a/src/utils/json_utils.py b/src/utils/json_utils.py index 2558473..946b526 100644 --- a/src/utils/json_utils.py +++ b/src/utils/json_utils.py @@ -3,11 +3,33 @@ import logging import json +from typing import Any import json_repair logger = logging.getLogger(__name__) +def sanitize_args(args: Any) -> str: + """ + Sanitize tool call arguments to prevent special character issues. + + Args: + args: Tool call arguments string + + Returns: + str: Sanitized arguments string + """ + if not isinstance(args, str): + return "" + else: + return ( + args.replace("[", "[") + .replace("]", "]") + .replace("{", "{") + .replace("}", "}") + ) + + def repair_json_output(content: str) -> str: """ Repair and normalize JSON output. 
diff --git a/tests/unit/checkpoint/test_checkpoint.py b/tests/unit/checkpoint/test_checkpoint.py new file mode 100644 index 0000000..2109a3c --- /dev/null +++ b/tests/unit/checkpoint/test_checkpoint.py @@ -0,0 +1,660 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import src.graph.checkpoint as checkpoint + +POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db" +MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin" + + +def test_with_local_postgres_db(): + """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=POSTGRES_URL, + ) + assert manager.postgres_conn is not None + assert manager.mongo_client is None + + +def test_with_local_mongo_db(): + """Ensure the ChatStreamManager can be initialized with a local MongoDB.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + assert manager.mongo_db is not None + assert manager.postgres_conn is None + + +def test_init_without_checkpoint_saver(): + """Manager should not create DB clients when checkpoint_saver is False.""" + manager = checkpoint.ChatStreamManager(checkpoint_saver=False) + assert manager.checkpoint_saver is False + # DB connections are not created when saver is disabled + assert manager.mongo_client is None + assert manager.postgres_conn is None + + +def test_process_stream_partial_buffer_postgres(monkeypatch): + """Partial chunks should be buffered; Postgres init is stubbed to no-op.""" + + # Patch Postgres init to no-op + def _no_pg(self): + self.postgres_conn = None + + monkeypatch.setattr( + checkpoint.ChatStreamManager, "_init_postgresql", _no_pg, raising=True + ) + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=POSTGRES_URL, + ) + result = manager.process_stream_message("t1", "hello", finish_reason="partial") + assert result is 
True + # Verify the chunk was stored in the in-memory store + items = manager.store.search(("messages", "t1"), limit=10) + values = [it.dict()["value"] for it in items] + assert "hello" in values + + +def test_process_stream_partial_buffer_mongo(monkeypatch): + """Partial chunks should be buffered; Mongo init is stubbed to no-op.""" + + # Patch Mongo init to no-op for speed + def _no_mongo(self): + self.mongo_client = None + self.mongo_db = None + + monkeypatch.setattr( + checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True + ) + + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + result = manager.process_stream_message("t2", "hello", finish_reason="partial") + assert result is True + # Verify the chunk was stored in the in-memory store + items = manager.store.search(("messages", "t2"), limit=10) + values = [it.dict()["value"] for it in items] + assert "hello" in values + + +def test_persist_postgresql_local_db(): + """Ensure that the ChatStreamManager can persist to a local PostgreSQL DB.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=POSTGRES_URL, + ) + assert manager.postgres_conn is not None + assert manager.mongo_client is None + + # Simulate a message to persist + thread_id = "test_thread" + messages = ["This is a test message."] + result = manager._persist_to_postgresql(thread_id, messages) + assert result is True + # Simulate a message with existing thread + result = manager._persist_to_postgresql(thread_id, ["Another message."]) + assert result is True + + +def test_persist_postgresql_called_with_aggregated_chunks(monkeypatch): + """On 'stop', aggregated chunks should be passed to PostgreSQL persist method.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=POSTGRES_URL, + ) + + assert ( + manager.process_stream_message("thd3", "Hello", finish_reason="partial") is True + ) + assert ( + manager.process_stream_message("thd3", " World", 
finish_reason="stop") is True + ) + + # Verify the messages were aggregated correctly + with manager.postgres_conn.cursor() as cursor: + # Check if conversation already exists + cursor.execute( + "SELECT messages FROM chat_streams WHERE thread_id = %s", ("thd3",) + ) + existing_record = cursor.fetchone() + assert existing_record is not None + assert existing_record["messages"] == ["Hello", " World"] + + +def test_persist_not_attempted_when_saver_disabled(): + """When saver disabled, stop should not persist and should return False.""" + manager = checkpoint.ChatStreamManager(checkpoint_saver=False) + # stop should try to persist, but saver disabled => returns False + assert manager.process_stream_message("t4", "hello", finish_reason="stop") is False + + +def test_persist_mongodb_local_db(): + """Ensure that the ChatStreamManager can persist to a local MongoDB.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + assert manager.mongo_db is not None + assert manager.postgres_conn is None + # Simulate a message to persist + thread_id = "test_thread" + messages = ["This is a test message."] + result = manager._persist_to_mongodb(thread_id, messages) + assert result is True + # Simulate a message with existing thread + result = manager._persist_to_mongodb(thread_id, ["Another message."]) + assert result is True + + +def test_persist_mongodb_called_with_aggregated_chunks(monkeypatch): + """On 'stop', aggregated chunks should be passed to MongoDB persist method.""" + + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + + assert ( + manager.process_stream_message("thd5", "Hello", finish_reason="partial") is True + ) + assert ( + manager.process_stream_message("thd5", " World", finish_reason="stop") is True + ) + + # Verify the messages were aggregated correctly + collection = manager.mongo_db.chat_streams + existing_record = collection.find_one({"thread_id": "thd5"}) + assert existing_record 
is not None + assert existing_record["messages"] == ["Hello", " World"] + + +def test_invalid_inputs_return_false(monkeypatch): + """Empty thread_id or message should be rejected and return False.""" + + def _no_mongo(self): + self.mongo_client = None + self.mongo_db = None + + monkeypatch.setattr( + checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True + ) + + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + assert manager.process_stream_message("", "msg", finish_reason="partial") is False + assert manager.process_stream_message("tid", "", finish_reason="partial") is False + + +def test_unsupported_db_uri_scheme(): + """Manager should log warning for unsupported database URI schemes.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, db_uri="redis://localhost:6379/0" + ) + # Should not have any database connections + assert manager.mongo_client is None + assert manager.postgres_conn is None + assert manager.mongo_db is None + + +def test_process_stream_with_interrupt_finish_reason(): + """Test that 'interrupt' finish_reason triggers persistence like 'stop'.""" + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + + # Add partial message + assert ( + manager.process_stream_message( + "int_test", "Interrupted", finish_reason="partial" + ) + is True + ) + # Interrupt should trigger persistence + assert ( + manager.process_stream_message( + "int_test", " message", finish_reason="interrupt" + ) + is True + ) + + +def test_postgresql_connection_failure(monkeypatch): + """Test PostgreSQL connection failure handling.""" + + def failing_connect(dsn, **kwargs): + raise RuntimeError("Connection failed") + + monkeypatch.setattr("psycopg.connect", failing_connect) + + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=POSTGRES_URL, + ) + # Should have no postgres connection on failure + assert manager.postgres_conn is None + + +def 
test_mongodb_ping_failure(monkeypatch): + """Test MongoDB ping failure during initialization.""" + + class FakeAdmin: + def command(self, name): + raise RuntimeError("Ping failed") + + class FakeClient: + def __init__(self, uri): + self.admin = FakeAdmin() + + monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri)) + + manager = checkpoint.ChatStreamManager( + checkpoint_saver=True, + db_uri=MONGO_URL, + ) + # Should not have mongo_db set on ping failure + assert getattr(manager, "mongo_db", None) is None + + +def test_store_namespace_consistency(): + """Test that store namespace is consistently used across methods.""" + manager = checkpoint.ChatStreamManager(checkpoint_saver=False) + + # Process a partial message + assert ( + manager.process_stream_message("ns_test", "chunk1", finish_reason="partial") + is True + ) + + # Verify cursor is stored correctly + cursor = manager.store.get(("messages", "ns_test"), "cursor") + assert cursor is not None + assert cursor.value["index"] == 0 + + # Add another chunk + assert ( + manager.process_stream_message("ns_test", "chunk2", finish_reason="partial") + is True + ) + + # Verify cursor is incremented + cursor = manager.store.get(("messages", "ns_test"), "cursor") + assert cursor.value["index"] == 1 + + +def test_cursor_initialization_edge_cases(): + """Test cursor handling edge cases.""" + manager = checkpoint.ChatStreamManager(checkpoint_saver=False) + + # Manually set a cursor with missing index + namespace = ("messages", "edge_test") + manager.store.put(namespace, "cursor", {}) # Missing 'index' key + + # Should handle missing index gracefully + result = manager.process_stream_message( + "edge_test", "test", finish_reason="partial" + ) + assert result is True + + # Should default to 0 and increment to 1 + cursor = manager.store.get(namespace, "cursor") + assert cursor.value["index"] == 1 + + +def test_multiple_threads_isolation(): + """Test that different thread_ids are properly isolated.""" + manager 
= checkpoint.ChatStreamManager(checkpoint_saver=False) + + # Process messages for different threads + assert ( + manager.process_stream_message("thread1", "msg1", finish_reason="partial") + is True + ) + assert ( + manager.process_stream_message("thread2", "msg2", finish_reason="partial") + is True + ) + assert ( + manager.process_stream_message("thread1", "msg3", finish_reason="partial") + is True + ) + + # Verify isolation + thread1_items = manager.store.search(("messages", "thread1"), limit=10) + thread2_items = manager.store.search(("messages", "thread2"), limit=10) + + thread1_values = [ + item.dict()["value"] + for item in thread1_items + if isinstance(item.dict()["value"], str) + ] + thread2_values = [ + item.dict()["value"] + for item in thread2_items + if isinstance(item.dict()["value"], str) + ] + + assert "msg1" in thread1_values + assert "msg3" in thread1_values + assert "msg2" in thread2_values + assert "msg1" not in thread2_values + assert "msg2" not in thread1_values + + +def test_mongodb_insert_and_update_paths(monkeypatch): + """Exercise MongoDB insert, update, and exception branches.""" + + # Fake Mongo classes + class FakeUpdateResult: + def __init__(self, modified_count): + self.modified_count = modified_count + + class FakeInsertResult: + def __init__(self, inserted_id): + self.inserted_id = inserted_id + + class FakeCollection: + def __init__(self, mode="insert_success"): + self.mode = mode + + def find_one(self, query): + if self.mode.startswith("insert"): + return None + return {"thread_id": query["thread_id"]} + + def update_one(self, q, s): + if self.mode == "update_success": + return FakeUpdateResult(1) + return FakeUpdateResult(0) + + def insert_one(self, doc): + if self.mode == "insert_success": + return FakeInsertResult("ok") + if self.mode == "insert_none": + return FakeInsertResult(None) + raise RuntimeError("boom") + + class FakeMongoDB: + def __init__(self, mode): + self.chat_streams = FakeCollection(mode) + + manager = 
checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) + + # Insert success + manager.mongo_db = FakeMongoDB("insert_success") + assert manager._persist_to_mongodb("th1", ["a"]) is True + + # Insert returns None id => False + manager.mongo_db = FakeMongoDB("insert_none") + assert manager._persist_to_mongodb("th2", ["a"]) is False + + # Insert raises => False + manager.mongo_db = FakeMongoDB("insert_raise") + assert manager._persist_to_mongodb("th3", ["a"]) is False + + # Update success + manager.mongo_db = FakeMongoDB("update_success") + assert manager._persist_to_mongodb("th4", ["a"]) is True + + # Update modifies 0 => False + manager.mongo_db = FakeMongoDB("update_zero") + assert manager._persist_to_mongodb("th5", ["a"]) is False + + +def test_postgresql_insert_update_and_error_paths(): + """Exercise PostgreSQL update, insert, and error/rollback branches.""" + calls = {"executed": []} + + class FakeCursor: + def __init__(self, mode): + self.mode = mode + self.rowcount = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, sql, params=None): + calls["executed"].append(sql.strip().split()[0]) + if "SELECT" in sql: + if self.mode == "update": + self._fetch = {"id": "x"} + elif self.mode == "error": + raise RuntimeError("sql error") + else: + self._fetch = None + else: + # UPDATE or INSERT + self.rowcount = 1 + + def fetchone(self): + return getattr(self, "_fetch", None) + + class FakeConn: + def __init__(self, mode): + self.mode = mode + self.commit_called = False + self.rollback_called = False + + def cursor(self): + return FakeCursor(self.mode) + + def commit(self): + self.commit_called = True + + def rollback(self): + self.rollback_called = True + + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL) + + # Update path + manager.postgres_conn = FakeConn("update") + assert manager._persist_to_postgresql("t", ["m"]) is True + assert 
manager.postgres_conn.commit_called is True + + # Insert path + manager.postgres_conn = FakeConn("insert") + assert manager._persist_to_postgresql("t", ["m"]) is True + assert manager.postgres_conn.commit_called is True + + # Error path with rollback + manager.postgres_conn = FakeConn("error") + assert manager._persist_to_postgresql("t", ["m"]) is False + assert manager.postgres_conn.rollback_called is True + + +def test_create_chat_streams_table_success_and_error(): + """Ensure table creation commits on success and rolls back on failure.""" + + class FakeCursor: + def __init__(self, should_fail=False): + self.should_fail = should_fail + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, sql): + if self.should_fail: + raise RuntimeError("ddl fail") + + class FakeConn: + def __init__(self, should_fail=False): + self.should_fail = should_fail + self.commits = 0 + self.rollbacks = 0 + + def cursor(self): + return FakeCursor(self.should_fail) + + def commit(self): + self.commits += 1 + + def rollback(self): + self.rollbacks += 1 + + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL) + + # Success + manager.postgres_conn = FakeConn(False) + manager._create_chat_streams_table() + assert manager.postgres_conn.commits == 1 + + # Failure triggers rollback + manager.postgres_conn = FakeConn(True) + manager._create_chat_streams_table() + assert manager.postgres_conn.rollbacks == 1 + + +def test_close_closes_resources_and_handles_errors(): + """Close should gracefully handle both success and exceptions.""" + flags = {"mongo": 0, "pg": 0} + + class M: + def close(self): + flags["mongo"] += 1 + + class P: + def __init__(self, raise_on_close=False): + self.raise_on_close = raise_on_close + + def close(self): + if self.raise_on_close: + raise RuntimeError("close fail") + flags["pg"] += 1 + + manager = checkpoint.ChatStreamManager(checkpoint_saver=False) + manager.mongo_client = M() + 
manager.postgres_conn = P() + manager.close() + assert flags == {"mongo": 1, "pg": 1} + + # Trigger error branches (no raise escapes) + manager.mongo_client = None # skip mongo + manager.postgres_conn = P(True) + manager.close() # should handle exception gracefully + + +def test_context_manager_calls_close(monkeypatch): + """The context manager protocol should call close() on exit.""" + called = {"close": 0} + + def _noop(self): + self.mongo_client = None + self.mongo_db = None + + monkeypatch.setattr( + checkpoint.ChatStreamManager, "_init_mongodb", _noop, raising=True + ) + + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) + + def fake_close(): + called["close"] += 1 + + manager.close = fake_close + with manager: + pass + assert called["close"] == 1 + + +def test_init_mongodb_success_and_failure(monkeypatch): + """MongoDB init should succeed with a valid client and fail gracefully otherwise.""" + + class FakeAdmin: + def command(self, name): + assert name == "ping" + + class DummyDB: + pass + + class FakeClient: + def __init__(self, uri): + self.uri = uri + self.admin = FakeAdmin() + self.checkpointing_db = DummyDB() + + def close(self): + pass + + # Success path + monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri)) + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) + assert manager.mongo_db is not None + + # Failure path + class Boom: + def __init__(self, uri): + raise RuntimeError("fail connect") + + monkeypatch.setattr(checkpoint, "MongoClient", Boom) + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) + # Should have no mongo_db set on failure + assert getattr(manager, "mongo_db", None) is None + + +def test_init_postgresql_calls_connect_and_create_table(monkeypatch): + """PostgreSQL init should connect and create the required table.""" + flags = {"connected": 0, "created": 0} + + class FakeConn: + def __init__(self): + pass + + def 
close(self): + pass + + def fake_connect(self): + flags["connected"] += 1 + flags["created"] += 1 + return FakeConn() + + monkeypatch.setattr( + checkpoint.ChatStreamManager, "_init_postgresql", fake_connect, raising=True + ) + + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL) + assert manager.postgres_conn is None + assert flags == {"connected": 1, "created": 1} + + +def test_chat_stream_message_wrapper(monkeypatch): + """Wrapper should delegate when enabled and return False when disabled.""" + # When saver enabled, should call default manager + monkeypatch.setattr( + checkpoint, "get_bool_env", lambda k, d=False: True, raising=True + ) + + called = {"args": None} + + def fake_process(tid, msg, fr): + called["args"] = (tid, msg, fr) + return True + + monkeypatch.setattr( + checkpoint._default_manager, + "process_stream_message", + fake_process, + raising=True, + ) + assert checkpoint.chat_stream_message("tid", "msg", "stop") is True + assert called["args"] == ("tid", "msg", "stop") + + # When saver disabled, returns False and does not call manager + monkeypatch.setattr( + checkpoint, "get_bool_env", lambda k, d=False: False, raising=True + ) + called["args"] = None + assert checkpoint.chat_stream_message("tid", "msg", "stop") is False + assert called["args"] is None diff --git a/uv.lock b/uv.lock index 3c6e58f..e01e8c2 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.13'", @@ -383,6 +383,8 @@ dependencies = [ { name = "langchain-mcp-adapters" }, { name = "langchain-openai" }, { name = "langgraph" }, + { name = "langgraph-checkpoint-mongodb" }, + { name = "langgraph-checkpoint-postgres" }, { name = "litellm" }, { name = "markdownify" }, { name = "mcp" }, @@ -425,6 +427,8 @@ requires-dist = [ { name = "langchain-mcp-adapters", specifier = ">=0.0.9" }, { name = "langchain-openai", specifier = ">=0.3.8" }, { name = 
"langgraph", specifier = ">=0.3.5" }, + { name = "langgraph-checkpoint-mongodb", specifier = ">=0.1.4" }, + { name = "langgraph-checkpoint-postgres", specifier = "==2.0.21" }, { name = "langgraph-cli", extras = ["inmem"], marker = "extra == 'dev'", specifier = ">=0.2.10" }, { name = "litellm", specifier = ">=1.63.11" }, { name = "markdownify", specifier = ">=1.1.0" }, @@ -454,6 +458,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "dnspython" +version = "2.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/4a/263763cb2ba3816dd94b08ad3a33d5fdae34ecb856678773cc40a3605829/dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1", size = 345197, upload-time = "2024-10-05T20:14:59.362Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632, upload-time = "2024-10-05T20:14:57.687Z" }, +] + [[package]] name = "duckduckgo-search" version = "8.0.0" @@ -946,6 +959,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/eb/9e98822d3db22beff44449a8f61fca208d4f59d592a7ce67ce4c400b8f8f/langchain_mcp_adapters-0.1.9-py3-none-any.whl", hash = "sha256:fd131009c60c9e5a864f96576bbe757fc1809abd604891cb2e5d6e8aebd6975c", size = 15300, upload-time = "2025-07-09T15:56:13.316Z" }, ] +[[package]] +name = "langchain-mongodb" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain" }, + { name = "langchain-core" }, + { name = 
"langchain-text-splitters" }, + { name = "lark" }, + { name = "numpy" }, + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/02/82/ea0cfd092843f09c71d51d9b1fc3a5051fa477fb5d94b95e3b7bf73d6fd2/langchain_mongodb-0.6.2.tar.gz", hash = "sha256:fae221017b5db8a239837b2d163cf6493de6072e217239a95e3aa1fc3303615e", size = 237734, upload-time = "2025-05-12T14:37:51.964Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/fa/d47c1a3dd7ff07709023ff75284544a0d315c51e6de69ac01ff90a642dfe/langchain_mongodb-0.6.2-py3-none-any.whl", hash = "sha256:17f740e16582b8b6b241e625fb5c0ae273af1ccb13f2877ade1f1f14049e51a0", size = 59133, upload-time = "2025-05-12T14:37:50.603Z" }, +] + [[package]] name = "langchain-openai" version = "0.3.22" @@ -1032,6 +1062,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/52/bceb5b5348c7a60ef0625ab0a0a0a9ff5d78f0e12aed8cc55c49d5e8a8c9/langgraph_checkpoint-2.0.25-py3-none-any.whl", hash = "sha256:23416a0f5bc9dd712ac10918fc13e8c9c4530c419d2985a441df71a38fc81602", size = 42312, upload-time = "2025-04-26T21:00:42.242Z" }, ] +[[package]] +name = "langgraph-checkpoint-mongodb" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-mongodb" }, + { name = "langgraph-checkpoint" }, + { name = "motor" }, + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/c8/062bef92e96a36ea14658f16c43b87892a98c65137b4e088044790960e91/langgraph_checkpoint_mongodb-0.1.4.tar.gz", hash = "sha256:cae9a63a80d8259388b23e941438b7ae56e20570c1f39f640ccb9f28f77a67fe", size = 144572, upload-time = "2025-06-13T20:20:06.563Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/9b/b78c4366cf21bb908f90e348999b9353f85d7c2e40a4e169bb9ebe6ff37a/langgraph_checkpoint_mongodb-0.1.4-py3-none-any.whl", hash = "sha256:5edfa15e0fc03c27b7dda840a001bd210f16abc75b25b64405bf901ca655fdb5", size = 11226, upload-time = 
"2025-06-13T20:20:05.855Z" }, +] + +[[package]] +name = "langgraph-checkpoint-postgres" +version = "2.0.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langgraph-checkpoint" }, + { name = "orjson" }, + { name = "psycopg" }, + { name = "psycopg-pool" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/51/18138b116807180c97093890dd7955f0e68eabcba97d323a61275bab45b6/langgraph_checkpoint_postgres-2.0.21.tar.gz", hash = "sha256:921915fd3de534b4c84469f93d03046c1ef1f224e44629212b172ec3e9b72ded", size = 30371, upload-time = "2025-04-18T16:31:50.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/31/d5f4a7dd63dddfdb85209a3cbc1778b14bc0dddadb431e34938956f45e8c/langgraph_checkpoint_postgres-2.0.21-py3-none-any.whl", hash = "sha256:f0a50f2c1496778e00ea888415521bb2b7789a12052aa5ae54d82cf517b271e8", size = 39440, upload-time = "2025-04-18T16:31:48.838Z" }, +] + [[package]] name = "langgraph-cli" version = "0.2.10" @@ -1112,6 +1172,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/f4/c206c0888f8a506404cb4f16ad89593bdc2f70cf00de26a1a0a7a76ad7a3/langsmith-0.3.45-py3-none-any.whl", hash = "sha256:5b55f0518601fa65f3bb6b1a3100379a96aa7b3ed5e9380581615ba9c65ed8ed", size = 363002, upload-time = "2025-06-05T05:10:27.228Z" }, ] +[[package]] +name = "lark" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132, upload-time = "2024-08-13T19:49:00.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036, upload-time = "2024-08-13T19:48:58.603Z" 
}, +] + [[package]] name = "litellm" version = "1.63.11" @@ -1261,6 +1330,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2f/cf/3fd38cfe43962452e4bfadc6966b2ea0afaf8e0286cb3991c247c8c33ebd/mcp-1.12.2-py3-none-any.whl", hash = "sha256:b86d584bb60193a42bd78aef01882c5c42d614e416cbf0480149839377ab5a5f", size = 158473, upload-time = "2025-07-24T18:29:03.419Z" }, ] +[[package]] +name = "motor" +version = "3.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/ae/96b88362d6a84cb372f7977750ac2a8aed7b2053eed260615df08d5c84f4/motor-3.7.1.tar.gz", hash = "sha256:27b4d46625c87928f331a6ca9d7c51c2f518ba0e270939d395bc1ddc89d64526", size = 280997, upload-time = "2025-05-14T18:56:33.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/9a/35e053d4f442addf751ed20e0e922476508ee580786546d699b0567c4c67/motor-3.7.1-py3-none-any.whl", hash = "sha256:8a63b9049e38eeeb56b4fdd57c3312a6d1f25d01db717fe7d82222393c410298", size = 74996, upload-time = "2025-05-14T18:56:31.665Z" }, +] + [[package]] name = "multidict" version = "6.1.0" @@ -1603,6 +1684,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101, upload-time = "2025-02-20T19:03:27.202Z" }, ] +[[package]] +name = "psycopg" +version = "3.2.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/4a/93a6ab570a8d1a4ad171a1f4256e205ce48d828781312c0bbaff36380ecb/psycopg-3.2.9.tar.gz", hash = "sha256:2fbb46fcd17bc81f993f28c47f1ebea38d66ae97cc2dbc3cad73b37cefbff700", size = 158122, 
upload-time = "2025-05-13T16:11:15.533Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/b0/a73c195a56eb6b92e937a5ca58521a5c3346fb233345adc80fd3e2f542e2/psycopg-3.2.9-py3-none-any.whl", hash = "sha256:01a8dadccdaac2123c916208c96e06631641c0566b22005493f09663c7a8d3b6", size = 202705, upload-time = "2025-05-13T16:06:26.584Z" }, +] + +[[package]] +name = "psycopg-pool" +version = "3.2.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/13/1e7850bb2c69a63267c3dbf37387d3f71a00fd0e2fa55c5db14d64ba1af4/psycopg_pool-3.2.6.tar.gz", hash = "sha256:0f92a7817719517212fbfe2fd58b8c35c1850cdd2a80d36b581ba2085d9148e5", size = 29770, upload-time = "2025-02-26T12:03:47.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/fd/4feb52a55c1a4bd748f2acaed1903ab54a723c47f6d0242780f4d97104d4/psycopg_pool-3.2.6-py3-none-any.whl", hash = "sha256:5887318a9f6af906d041a0b1dc1c60f8f0dda8340c2572b74e10907b51ed5da7", size = 38252, upload-time = "2025-02-26T12:03:45.073Z" }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1687,6 +1793,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ] +[[package]] +name = "pymongo" +version = "4.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/27/3634b2e8d88ad210ee6edac69259c698aefed4a79f0f7356cd625d5c423c/pymongo-4.12.1.tar.gz", hash = "sha256:8921bac7f98cccb593d76c4d8eaa1447e7d537ba9a2a202973e92372a05bd1eb", size = 2165515, upload-time = "2025-04-29T18:46:23.62Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/59/dd/b684de28bfaf7e296538601c514d4613f98b77cfa1de323c7b160f4e04d0/pymongo-4.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a7b771aa2f0854ddf7861e8ce2365f29df9159393543d047e43d8475bc4b8813", size = 910797, upload-time = "2025-04-29T18:44:57.783Z" }, + { url = "https://files.pythonhosted.org/packages/e8/80/4fadd5400a4fbe57e7ea0349f132461d5dfc46c124937600f5044290d817/pymongo-4.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34fd8681b6fa6e1025dd1000004f6b81cbf1961f145b8c58bd15e3957976068d", size = 910489, upload-time = "2025-04-29T18:45:01.089Z" }, + { url = "https://files.pythonhosted.org/packages/4e/83/303be22944312cc28e3a357556d21971c388189bf90aebc79e752afa2452/pymongo-4.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981e19b8f1040247dee5f7879e45f640f7e21a4d87eabb19283ce5a2927dd2e7", size = 1689142, upload-time = "2025-04-29T18:45:03.008Z" }, + { url = "https://files.pythonhosted.org/packages/a4/67/f4e8506caf001ab9464df2562e3e022b7324e7c10a979ce1b55b006f2445/pymongo-4.12.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9a487dc1fe92736987a156325d3d9c66cbde6eac658b2875f5f222b6d82edca", size = 1753373, upload-time = "2025-04-29T18:45:04.874Z" }, + { url = "https://files.pythonhosted.org/packages/2e/7c/22d65c2a4e3e941b345b8cc164b3b53f2c1d0db581d4991817b6375ef507/pymongo-4.12.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1525051c13984365c4a9b88ee2d63009fae277921bc89a0d323b52c51f91cbac", size = 1722399, upload-time = "2025-04-29T18:45:06.726Z" }, + { url = "https://files.pythonhosted.org/packages/07/0d/32fd1ebafd0090510fb4820d175fe35d646e5b28c71ad9c36cb3ce554567/pymongo-4.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ad689e0e4f364809084f9e5888b2dcd6f0431b682a1c68f3fdf241e20e14475", size = 1692374, upload-time = "2025-04-29T18:45:08.552Z" }, + { url = 
"https://files.pythonhosted.org/packages/e3/9c/d7a30ce6b983c3955c225e3038dafb4f299281775323f58b378f2a7e6e59/pymongo-4.12.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f9b18abca210c2917041ab2a380c12f6ddd2810844f1d64afb39caf8a15425e", size = 1651490, upload-time = "2025-04-29T18:45:10.658Z" }, + { url = "https://files.pythonhosted.org/packages/29/b3/7902d73df1d088ec0c60c19ef4bd7894c6e6e4dfbfd7ab4ae4fbedc9427c/pymongo-4.12.1-cp312-cp312-win32.whl", hash = "sha256:d9d90fec041c6d695a639c26ca83577aa74383f5e3744fd7931537b208d5a1b5", size = 879521, upload-time = "2025-04-29T18:45:12.993Z" }, + { url = "https://files.pythonhosted.org/packages/8c/68/a17ff6472e6be12bae75f5d11db4e3dccc55e02dcd4e66cd87871790a20e/pymongo-4.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:d004b13e4f03d73a3ad38505ba84b61a2c8ba0a304f02fe1b27bfc986c244192", size = 897765, upload-time = "2025-04-29T18:45:15.296Z" }, + { url = "https://files.pythonhosted.org/packages/0c/4d/e6654f3ec6819980cbad77795ccf2275cd65d6df41375a22cdbbccef8416/pymongo-4.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:90de2b060d69c22658ada162a5380a0f88cb8c0149023241b9e379732bd36152", size = 965051, upload-time = "2025-04-29T18:45:17.516Z" }, + { url = "https://files.pythonhosted.org/packages/54/95/627a047c32789544a938abfd9311c914e622cb036ad16866e7e1b9b80239/pymongo-4.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edf4e05331ac875d3b27b4654b74d81e44607af4aa7d6bcd4a31801ca164e6fd", size = 964732, upload-time = "2025-04-29T18:45:19.478Z" }, + { url = "https://files.pythonhosted.org/packages/8f/6d/7a604e3ab5399f8fe1ca88abdbf7e54ceb6cf03e64f68b2ed192d9a5eaf5/pymongo-4.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa7a817c9afb7b8775d98c469ddb3fe9c17daf53225394c1a74893cf45d3ade9", size = 1953037, upload-time = "2025-04-29T18:45:22.115Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/d5/269388e7b0d02d35f55440baf1e0120320b6db1b555eaed7117d04b35402/pymongo-4.12.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d142ca531694e9324b3c9ba86c0e905c5f857599c4018a386c4dc02ca490fa", size = 2030467, upload-time = "2025-04-29T18:45:24.069Z" }, + { url = "https://files.pythonhosted.org/packages/4b/d0/04a6b48d6ca3fc2ff156185a3580799a748cf713239d6181e91234a663d3/pymongo-4.12.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5d4c0461f5cd84d9fe87d5a84b1bc16371c4dd64d56dcfe5e69b15c0545a5ac", size = 1994139, upload-time = "2025-04-29T18:45:26.215Z" }, + { url = "https://files.pythonhosted.org/packages/ad/65/0567052d52c0ac8aaa4baa700b39cdd1cf2481d2e59bd9817a3daf169ca0/pymongo-4.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43afd2f39182731ac9fb81bbc9439d539e4bd2eda72cdee829d2fa906a1c4d37", size = 1954947, upload-time = "2025-04-29T18:45:28.423Z" }, + { url = "https://files.pythonhosted.org/packages/c5/5b/db25747b288218dbdd97e9aeff6a3bfa3f872efb4ed06fa8bec67b2a121e/pymongo-4.12.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:827ac668c003da7b175b8e5f521850e2c182b4638a3dec96d97f0866d5508a1e", size = 1904374, upload-time = "2025-04-29T18:45:30.943Z" }, + { url = "https://files.pythonhosted.org/packages/fc/1e/6d0eb040c02ae655fafd63bd737e96d7e832eecfd0bd37074d0066f94a78/pymongo-4.12.1-cp313-cp313-win32.whl", hash = "sha256:7c2269b37f034124a245eaeb34ce031cee64610437bd597d4a883304babda3cd", size = 925869, upload-time = "2025-04-29T18:45:32.998Z" }, + { url = "https://files.pythonhosted.org/packages/59/b9/459da646d9750529f04e7e686f0cd8dd40174138826574885da334c01b16/pymongo-4.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:3b28ecd1305b89089be14f137ffbdf98a3b9f5c8dbbb2be4dec084f2813fbd5f", size = 948411, upload-time = "2025-04-29T18:45:35.445Z" }, + { url = 
"https://files.pythonhosted.org/packages/c9/c3/75be116159f210811656ec615b2248f63f1bc9dd1ce641e18db2552160f0/pymongo-4.12.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:f27b22a8215caff68bdf46b5b61ccd843a68334f2aa4658e8d5ecb5d3fbebb3b", size = 1021562, upload-time = "2025-04-29T18:45:37.433Z" }, + { url = "https://files.pythonhosted.org/packages/cd/d1/2e8e368cad1c126a68365a6f53feaade58f9a16bd5f7a69f218af119b0e9/pymongo-4.12.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e9d23a3c290cf7409515466a7f11069b70e38ea2b786bbd7437bdc766c9e176", size = 1021553, upload-time = "2025-04-29T18:45:39.344Z" }, + { url = "https://files.pythonhosted.org/packages/17/6e/a6460bc1e3d3f5f46cc151417427b2687a6f87972fd68a33961a37c114df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efeb430f7ca8649a6544a50caefead343d1fd096d04b6b6a002c6ce81148a85c", size = 2281736, upload-time = "2025-04-29T18:45:41.462Z" }, + { url = "https://files.pythonhosted.org/packages/1a/e2/9e1d6f1a492bb02116074baa832716805a0552d757c176e7c5f40867ca80/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a34e4a08bbcff56fdee86846afbc9ce751de95706ca189463e01bf5de3dd9927", size = 2368964, upload-time = "2025-04-29T18:45:43.579Z" }, + { url = "https://files.pythonhosted.org/packages/fa/df/88143016eca77e79e38cf072476c70dd360962934430447dabc9c6bef6df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b063344e0282537f05dbb11147591cbf58fc09211e24fc374749e343f880910a", size = 2327834, upload-time = "2025-04-29T18:45:45.847Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0d/df2998959b52cd5682b11e6eee1b0e0c104c07abd99c9cde5a871bb299fd/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3f7941e01b3e5d4bfb3b4711425e809df8c471b92d1da8d6fab92c7e334a4cb", size = 2279126, upload-time = "2025-04-29T18:45:48.445Z" }, + { url = 
"https://files.pythonhosted.org/packages/fb/3e/102636f5aaf97ccfa2a156c253a89f234856a0cd252fa602d4bf077ba3c0/pymongo-4.12.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b41235014031739f32be37ff13992f51091dae9a5189d3bcc22a5bf81fd90dae", size = 2218136, upload-time = "2025-04-29T18:45:50.57Z" }, + { url = "https://files.pythonhosted.org/packages/44/c9/1b534c9d8d91d9d98310f2d955c5331fb522bd2a0105bd1fc31771d53758/pymongo-4.12.1-cp313-cp313t-win32.whl", hash = "sha256:9a1f07fe83a8a34651257179bd38d0f87bd9d90577fcca23364145c5e8ba1bc0", size = 974747, upload-time = "2025-04-29T18:45:52.66Z" }, + { url = "https://files.pythonhosted.org/packages/08/e2/7d3a30ac905c99ea93729e03d2bb3d16fec26a789e98407d61cb368ab4bb/pymongo-4.12.1-cp313-cp313t-win_amd64.whl", hash = "sha256:46d86cf91ee9609d0713242a1d99fa9e9c60b4315e1a067b9a9e769bedae629d", size = 1003332, upload-time = "2025-04-29T18:45:54.631Z" }, +] + [[package]] name = "pytest" version = "8.3.5" diff --git a/web/src/core/messages/merge-message.ts b/web/src/core/messages/merge-message.ts index 6a4cabf..6877546 100644 --- a/web/src/core/messages/merge-message.ts +++ b/web/src/core/messages/merge-message.ts @@ -49,7 +49,11 @@ function mergeTextMessage(message: Message, event: MessageChunkEvent) { message.reasoningContentChunks.push(event.data.reasoning_content); } } - +function convertToolChunkArgs(args: string) { + // Convert escaped characters in args + if (!args) return ""; + return args.replace(/[/g, "[").replace(/]/g, "]").replace(/{/g, "{").replace(/}/g, "}"); +} function mergeToolCallMessage( message: Message, event: ToolCallsEvent | ToolCallChunksEvent, @@ -70,14 +74,14 @@ function mergeToolCallMessage( (toolCall) => toolCall.id === chunk.id, ); if (toolCall) { - toolCall.argsChunks = [chunk.args]; + toolCall.argsChunks = [convertToolChunkArgs(chunk.args)]; } } else { const streamingToolCall = message.toolCalls.find( (toolCall) => toolCall.argsChunks?.length, 
); if (streamingToolCall) { - streamingToolCall.argsChunks!.push(chunk.args); + streamingToolCall.argsChunks!.push(convertToolChunkArgs(chunk.args)); } } }