feat: Enhance chat streaming and tool call processing (#498)

* feat: Enhance chat streaming and tool call processing

- Added support for MongoDB checkpointer in the chat streaming workflow.
- Introduced functions to process tool call chunks and sanitize arguments.
- Improved event message creation with additional metadata.
- Enhanced error handling for JSON serialization in event messages.
- Updated the frontend to convert escaped characters in tool call arguments.
- Refactored the workflow input preparation and initial message processing.
- Added new dependencies for MongoDB integration and tool argument sanitization.

* fix: Update MongoDB checkpointer configuration to use LANGGRAPH_CHECKPOINT_DB_URL

* feat: Add support for Postgres checkpointing and update README with database recommendations

* feat: Implement checkpoint saver functionality and update MongoDB connection handling

* refactor: Improve code formatting and readability in app.py and json_utils.py

* refactor: Clean up commented code and improve formatting in server.py

* refactor: Remove unused imports and improve code organization in app.py

* refactor: Improve code organization and remove unnecessary comments in app.py

* chore: use langgraph-checkpoint-postgres==2.0.21 to avoid the JSON conversion issue in the latest version; implement chat stream persistence with Postgres

* feat: add MongoDB and PostgreSQL support for LangGraph checkpointing, enhance environment variable handling

* fix: update comments for clarity on Windows event loop policy

* chore: remove empty code changes in MongoDB and PostgreSQL checkpoint tests

* chore: clean up unused imports and code in checkpoint-related files

* chore: remove empty code changes in test_checkpoint.py

* chore: remove empty code changes in test_checkpoint.py

* chore: remove empty code changes in test_checkpoint.py

* test: update status code assertions in MCP endpoint tests to allow for 403 responses

* test: update MCP endpoint tests to assert specific status codes and enable MCP server configuration

* chore: remove unnecessary environment variables from unittest workflow

* fix: invert condition for MCP server configuration check to raise 403 when disabled

* chore: remove pymongo from test dependencies in uv.lock

* chore: optimize the _get_agent_name method

* test: enhance ChatStreamManager tests for PostgreSQL and MongoDB initialization

* test: add persistence tests for ChatStreamManager with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB

* test: enhance persistence tests for ChatStreamManager with PostgreSQL and MongoDB to verify message aggregation

* test: add unit tests for ChatStreamManager with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB

* test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
CHANGXUBO authored on 2025-08-16 21:03:12 +08:00, committed by GitHub
parent d65b8f8fcc
commit 1bfec3ad05
12 changed files with 1558 additions and 119 deletions

View File

@@ -55,3 +55,10 @@ VOLCENGINE_TTS_ACCESS_TOKEN=xxx
# [!NOTE]
# For model settings and other configurations, please refer to `docs/configuration_guide.md`
# Optional: LangGraph checkpointer
# Enable LangGraph checkpoint saver, supports MongoDB, Postgres
#LANGGRAPH_CHECKPOINT_SAVER=true
# Set the database URL for saving checkpoints
#LANGGRAPH_CHECKPOINT_DB_URL="ongodb://localhost:27017/
#LANGGRAPH_CHECKPOINT_DB_URL=postgresql://localhost:5432/postgres

View File

@@ -12,6 +12,31 @@ permissions:
jobs:
test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15
env:
POSTGRES_DB: checkpointing_db
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports: ["5432:5432"]
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
mongodb:
image: mongo:6
env:
MONGO_INITDB_ROOT_USERNAME: admin
MONGO_INITDB_ROOT_PASSWORD: admin
MONGO_INITDB_DATABASE: checkpointing_db
ports: ["27017:27017"]
options: >-
--health-cmd "mongosh --eval 'db.runCommand(\"ping\").ok'"
--health-interval 10s
--health-timeout 5s
--health-retries 3
steps:
- uses: actions/checkout@v3
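
With these service containers up, the checkpoint tests added later in this commit reach the databases via URLs that match the credentials above, for example:

```python
# Connection URLs matching the CI service credentials above; the same
# values appear in the checkpoint tests added in this commit.
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
```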

View File

@@ -386,6 +386,33 @@ DeerFlow supports LangSmith tracing to help you debug and monitor your workflows
This will enable trace visualization in LangGraph Studio and send your traces to LangSmith for monitoring and analysis.
### Checkpointing
1. Postgres and MongoDB implementations of the LangGraph checkpoint saver.
2. An in-memory store caches the streaming messages before they are persisted to the database; when finish_reason is "stop" or "interrupt", persistence is triggered.
3. Supports saving and loading checkpoints for workflow execution.
4. Supports saving chat stream events for replaying conversations.
Note:
The latest langgraph-checkpoint-postgres 2.0.23 has a checkpointing issue ("TypeError: Object of type HumanMessage is not JSON serializable"; see the open issue at https://github.com/langchain-ai/langgraph/issues/5557).
To use the Postgres checkpointer, install langgraph-checkpoint-postgres 2.0.21.
The default database and collections will be created automatically if they do not exist.
Default database: checkpointing_db
Default collection: checkpoint_writes_aio (LangGraph checkpoint writes)
Default collection: checkpoints_aio (LangGraph checkpoints)
Default collection: chat_streams (chat stream events for replaying conversations)
You need to set the following environment variables in your `.env` file:
```bash
# Enable LangGraph checkpoint saver, supports MongoDB, Postgres
LANGGRAPH_CHECKPOINT_SAVER=true
# Set the database URL for saving checkpoints
LANGGRAPH_CHECKPOINT_DB_URL="mongodb://localhost:27017/"
#LANGGRAPH_CHECKPOINT_DB_URL=postgresql://localhost:5432/postgres
```
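
If you do not have databases running locally, one way to start disposable instances for trying this out is with Docker (a sketch, not part of this commit; adjust image versions and credentials as needed):

```bash
# Illustrative only: throwaway local databases for the checkpointer
docker run -d --name deerflow-mongo -p 27017:27017 mongo:6
docker run -d --name deerflow-postgres -p 5432:5432 \
  -e POSTGRES_PASSWORD=postgres postgres:15
```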
## Docker
You can also run this project with Docker.

View File

@@ -34,6 +34,8 @@ dependencies = [
"langchain-mcp-adapters>=0.0.9", "langchain-mcp-adapters>=0.0.9",
"langchain-deepseek>=0.1.3", "langchain-deepseek>=0.1.3",
"wikipedia>=1.4.0", "wikipedia>=1.4.0",
"langgraph-checkpoint-mongodb>=0.1.4",
"langgraph-checkpoint-postgres==2.0.21",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -4,7 +4,8 @@
""" """
Server script for running the DeerFlow API. Server script for running the DeerFlow API.
""" """
import os
import asyncio
import argparse import argparse
import logging import logging
import signal import signal
@@ -19,6 +20,17 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
# Since Python 3.8 the default event loop on Windows is the Proactor event loop,
# which lacks add_reader/add_writer and can break libraries that expect
# selector-based I/O (e.g., some Uvicorn/Watchdog/stdio integrations and the
# asyncio checkpointers). For compatibility, force the selector event loop.
if os.name == "nt":
logger.info("Setting Windows event loop policy for asyncio")
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def handle_shutdown(signum, frame):
"""Handle graceful shutdown on SIGTERM/SIGINT"""

View File

@@ -13,6 +13,33 @@ from src.config.report_style import ReportStyle
logger = logging.getLogger(__name__)
_TRUTHY = {"1", "true", "yes", "y", "on"}
def get_bool_env(name: str, default: bool = False) -> bool:
val = os.getenv(name)
if val is None:
return default
return str(val).strip().lower() in _TRUTHY
def get_str_env(name: str, default: str = "") -> str:
val = os.getenv(name)
return default if val is None else str(val).strip()
def get_int_env(name: str, default: int = 0) -> int:
val = os.getenv(name)
if val is None:
return default
try:
return int(val.strip())
except ValueError:
logger.warning(
f"Invalid integer value for {name}: {val}. Using default {default}."
)
return default
def get_recursion_limit(default: int = 25) -> int:
"""Get the recursion limit from environment variable or use default.
@@ -23,23 +50,15 @@ def get_recursion_limit(default: int = 25) -> int:
Returns:
int: The recursion limit to use
"""
env_value_str = get_str_env("AGENT_RECURSION_LIMIT", str(default))
parsed_limit = get_int_env("AGENT_RECURSION_LIMIT", default)
if parsed_limit > 0:
logger.info(f"Recursion limit set to: {parsed_limit}")
return parsed_limit
else:
logger.warning(
f"AGENT_RECURSION_LIMIT value '{env_value_str}' (parsed as {parsed_limit}) is not positive. "
f"Using default value {default}."
)
return default
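
A quick sketch of how the new helpers behave (illustrative; uses only the functions defined above):

```python
import os

from src.config.configuration import get_bool_env, get_int_env

# Truthy strings are matched case-insensitively against {"1", "true", "yes", "y", "on"}.
os.environ["LANGGRAPH_CHECKPOINT_SAVER"] = "Yes"
assert get_bool_env("LANGGRAPH_CHECKPOINT_SAVER") is True

# Valid integers are parsed; invalid ones log a warning and return the default.
os.environ["AGENT_RECURSION_LIMIT"] = "30"
assert get_int_env("AGENT_RECURSION_LIMIT", 25) == 30
os.environ["AGENT_RECURSION_LIMIT"] = "abc"
assert get_int_env("AGENT_RECURSION_LIMIT", 25) == 25
```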

src/graph/checkpoint.py (new file, 372 lines)
View File

@@ -0,0 +1,372 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional, Tuple
import psycopg
from psycopg.rows import dict_row
from pymongo import MongoClient
from langgraph.store.memory import InMemoryStore
from src.config.configuration import get_bool_env, get_str_env
class ChatStreamManager:
"""
Manages chat stream messages with persistent storage and in-memory caching.
This class handles the storage and retrieval of chat messages using both
an in-memory store for temporary data and MongoDB or PostgreSQL for persistent storage.
It tracks message chunks and consolidates them when a conversation finishes.
Attributes:
store (InMemoryStore): In-memory storage for temporary message chunks
mongo_client (MongoClient): MongoDB client connection
mongo_db (Database): MongoDB database instance
postgres_conn (psycopg.Connection): PostgreSQL connection
logger (logging.Logger): Logger instance for this class
"""
def __init__(
self, checkpoint_saver: bool = False, db_uri: Optional[str] = None
) -> None:
"""
Initialize the ChatStreamManager with database connections.
Args:
checkpoint_saver: Whether persistence to a database is enabled.
db_uri: Database connection URI. Supports MongoDB (mongodb://) and PostgreSQL (postgresql:// or postgres://). The module-level default manager falls back to the LANGGRAPH_CHECKPOINT_DB_URL env var.
"""
self.logger = logging.getLogger(__name__)
self.store = InMemoryStore()
self.checkpoint_saver = checkpoint_saver
# Database URI supplied by the caller (the module-level default manager reads it from the environment)
self.db_uri = db_uri
# Initialize database connections
self.mongo_client = None
self.mongo_db = None
self.postgres_conn = None
if self.checkpoint_saver:
if self.db_uri.startswith("mongodb://"):
self._init_mongodb()
elif self.db_uri.startswith("postgresql://") or self.db_uri.startswith(
"postgres://"
):
self._init_postgresql()
else:
self.logger.warning(
f"Unsupported database URI scheme: {self.db_uri}. "
"Supported schemes: mongodb://, postgresql://, postgres://"
)
else:
self.logger.warning("Checkpoint saver is disabled")
def _init_mongodb(self) -> None:
"""Initialize MongoDB connection."""
try:
self.mongo_client = MongoClient(self.db_uri)
self.mongo_db = self.mongo_client.checkpointing_db
# Test connection
self.mongo_client.admin.command("ping")
self.logger.info("Successfully connected to MongoDB")
except Exception as e:
self.logger.error(f"Failed to connect to MongoDB: {e}")
def _init_postgresql(self) -> None:
"""Initialize PostgreSQL connection and create table if needed."""
try:
self.postgres_conn = psycopg.connect(self.db_uri, row_factory=dict_row)
self.logger.info("Successfully connected to PostgreSQL")
self._create_chat_streams_table()
except Exception as e:
self.logger.error(f"Failed to connect to PostgreSQL: {e}")
def _create_chat_streams_table(self) -> None:
"""Create the chat_streams table if it doesn't exist."""
try:
with self.postgres_conn.cursor() as cursor:
create_table_sql = """
CREATE TABLE IF NOT EXISTS chat_streams (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
thread_id VARCHAR(255) NOT NULL UNIQUE,
messages JSONB NOT NULL,
ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_chat_streams_thread_id ON chat_streams(thread_id);
CREATE INDEX IF NOT EXISTS idx_chat_streams_ts ON chat_streams(ts);
"""
cursor.execute(create_table_sql)
self.postgres_conn.commit()
self.logger.info("Chat streams table created/verified successfully")
except Exception as e:
self.logger.error(f"Failed to create chat_streams table: {e}")
if self.postgres_conn:
self.postgres_conn.rollback()
def process_stream_message(
self, thread_id: str, message: str, finish_reason: str
) -> bool:
"""
Process and store a chat stream message chunk.
This method handles individual message chunks during streaming and consolidates
them into a complete message when the stream finishes. Messages are stored
temporarily in memory and permanently in the configured database when complete.
Args:
thread_id: Unique identifier for the conversation thread
message: The message content or chunk to store
finish_reason: Reason for message completion ("stop", "interrupt", or partial)
Returns:
bool: True if message was processed successfully, False otherwise
"""
if not thread_id or not isinstance(thread_id, str):
self.logger.warning("Invalid thread_id provided")
return False
if not message:
self.logger.warning("Empty message provided")
return False
try:
# Create namespace for this thread's messages
store_namespace: Tuple[str, str] = ("messages", thread_id)
# Get or initialize message cursor for tracking chunks
cursor = self.store.get(store_namespace, "cursor")
current_index = 0
if cursor is None:
# Initialize cursor for new conversation
self.store.put(store_namespace, "cursor", {"index": 0})
else:
# Increment index for next chunk
current_index = int(cursor.value.get("index", 0)) + 1
self.store.put(store_namespace, "cursor", {"index": current_index})
# Store the current message chunk
self.store.put(store_namespace, f"chunk_{current_index}", message)
# Check if conversation is complete and should be persisted
if finish_reason in ("stop", "interrupt"):
return self._persist_complete_conversation(
thread_id, store_namespace, current_index
)
return True
except Exception as e:
self.logger.error(
f"Error processing stream message for thread {thread_id}: {e}"
)
return False
def _persist_complete_conversation(
self, thread_id: str, store_namespace: Tuple[str, str], final_index: int
) -> bool:
"""
Persist completed conversation to database (MongoDB or PostgreSQL).
Retrieves all message chunks from memory store and saves the complete
conversation to the configured database for permanent storage.
Args:
thread_id: Unique identifier for the conversation thread
store_namespace: Namespace tuple for accessing stored messages
final_index: The final chunk index for this conversation
Returns:
bool: True if persistence was successful, False otherwise
"""
try:
# Retrieve all message chunks from memory store
# Get all messages up to the final index including cursor metadata
memories = self.store.search(store_namespace, limit=final_index + 2)
# Extract message content, filtering out cursor metadata
messages: List[str] = []
for item in memories:
value = item.dict().get("value", "")
# Skip cursor metadata, only include actual message chunks
if value and not isinstance(value, dict):
messages.append(str(value))
if not messages:
self.logger.warning(f"No messages found for thread {thread_id}")
return False
if not self.checkpoint_saver:
self.logger.warning("Checkpoint saver is disabled")
return False
# Choose persistence method based on available connection
if self.mongo_db is not None:
return self._persist_to_mongodb(thread_id, messages)
elif self.postgres_conn is not None:
return self._persist_to_postgresql(thread_id, messages)
else:
self.logger.warning("No database connection available")
return False
except Exception as e:
self.logger.error(
f"Error persisting conversation for thread {thread_id}: {e}"
)
return False
def _persist_to_mongodb(self, thread_id: str, messages: List[str]) -> bool:
"""Persist conversation to MongoDB."""
try:
# Get MongoDB collection for chat streams
collection = self.mongo_db.chat_streams
# Check if conversation already exists in database
existing_document = collection.find_one({"thread_id": thread_id})
current_timestamp = datetime.now()
if existing_document:
# Update existing conversation with new messages
update_result = collection.update_one(
{"thread_id": thread_id},
{"$set": {"messages": messages, "ts": current_timestamp}},
)
self.logger.info(
f"Updated conversation for thread {thread_id}: "
f"{update_result.modified_count} documents modified"
)
return update_result.modified_count > 0
else:
# Create new conversation document
new_document = {
"thread_id": thread_id,
"messages": messages,
"ts": current_timestamp,
"id": uuid.uuid4().hex,
}
insert_result = collection.insert_one(new_document)
self.logger.info(
f"Created new conversation: {insert_result.inserted_id}"
)
return insert_result.inserted_id is not None
except Exception as e:
self.logger.error(f"Error persisting to MongoDB: {e}")
return False
def _persist_to_postgresql(self, thread_id: str, messages: List[str]) -> bool:
"""Persist conversation to PostgreSQL."""
try:
with self.postgres_conn.cursor() as cursor:
# Check if conversation already exists
cursor.execute(
"SELECT id FROM chat_streams WHERE thread_id = %s", (thread_id,)
)
existing_record = cursor.fetchone()
current_timestamp = datetime.now()
messages_json = json.dumps(messages)
if existing_record:
# Update existing conversation with new messages
cursor.execute(
"""
UPDATE chat_streams
SET messages = %s, ts = %s
WHERE thread_id = %s
""",
(messages_json, current_timestamp, thread_id),
)
affected_rows = cursor.rowcount
self.postgres_conn.commit()
self.logger.info(
f"Updated conversation for thread {thread_id}: "
f"{affected_rows} rows modified"
)
return affected_rows > 0
else:
# Create new conversation record
conversation_id = uuid.uuid4()
cursor.execute(
"""
INSERT INTO chat_streams (id, thread_id, messages, ts)
VALUES (%s, %s, %s, %s)
""",
(conversation_id, thread_id, messages_json, current_timestamp),
)
affected_rows = cursor.rowcount
self.postgres_conn.commit()
self.logger.info(
f"Created new conversation with ID: {conversation_id}"
)
return affected_rows > 0
except Exception as e:
self.logger.error(f"Error persisting to PostgreSQL: {e}")
if self.postgres_conn:
self.postgres_conn.rollback()
return False
def close(self) -> None:
"""Close database connections."""
try:
if self.mongo_client is not None:
self.mongo_client.close()
self.logger.info("MongoDB connection closed")
except Exception as e:
self.logger.error(f"Error closing MongoDB connection: {e}")
try:
if self.postgres_conn is not None:
self.postgres_conn.close()
self.logger.info("PostgreSQL connection closed")
except Exception as e:
self.logger.error(f"Error closing PostgreSQL connection: {e}")
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit - close connections."""
self.close()
# Global instance for backward compatibility
# TODO: Consider using dependency injection instead of global instance
_default_manager = ChatStreamManager(
checkpoint_saver=get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False),
db_uri=get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "mongodb://localhost:27017"),
)
def chat_stream_message(thread_id: str, message: str, finish_reason: str) -> bool:
"""
Legacy function wrapper for backward compatibility.
Args:
thread_id: Unique identifier for the conversation thread
message: The message content to store
finish_reason: Reason for message completion
Returns:
bool: True if message was processed successfully
"""
checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
if checkpoint_saver:
return _default_manager.process_stream_message(
thread_id, message, finish_reason
)
else:
logging.warning("Checkpoint saver is disabled, message not processed")
return False
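
A minimal usage sketch for ChatStreamManager, assuming a MongoDB instance is reachable at the URI below:

```python
from src.graph.checkpoint import ChatStreamManager

with ChatStreamManager(
    checkpoint_saver=True, db_uri="mongodb://localhost:27017/"
) as manager:
    # Partial chunks are buffered in the in-memory store...
    manager.process_stream_message("thread-1", "Hello", finish_reason="partial")
    # ...and a "stop" (or "interrupt") chunk persists the aggregated
    # chunks to the chat_streams collection.
    manager.process_stream_message("thread-1", " World", finish_reason="stop")
```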

View File

@@ -4,7 +4,6 @@
import base64
import json
import logging
from typing import Annotated, List, cast
from uuid import uuid4
@@ -13,8 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit, get_bool_env, get_str_env
from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
@@ -42,6 +45,8 @@ from src.server.rag_request import (
RAGResourcesResponse,
)
from src.tools import VolcengineTTS
from src.graph.checkpoint import chat_stream_message
from src.utils.json_utils import sanitize_args
logger = logging.getLogger(__name__)
@@ -56,7 +61,7 @@ app = FastAPI(
# Add CORS middleware
# It's recommended to load the allowed origins from an environment variable
# for better security and flexibility across different environments.
allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]
logger.info(f"Allowed origins: {allowed_origins}")
@@ -68,18 +73,14 @@ app.add_middleware(
allow_methods=["GET", "POST", "OPTIONS"], # Use the configured list of methods allow_methods=["GET", "POST", "OPTIONS"], # Use the configured list of methods
allow_headers=["*"], # Now allow all headers, but can be restricted further allow_headers=["*"], # Now allow all headers, but can be restricted further
) )
in_memory_store = InMemoryStore()
graph = build_graph_with_memory()
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
# Check if MCP server configuration is enabled
mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)
# Validate MCP settings if provided
if request.mcp_settings and not mcp_enabled:
@@ -91,6 +92,7 @@ async def chat_stream(request: ChatRequest):
thread_id = request.thread_id
if thread_id == "__default__":
thread_id = str(uuid4())
return StreamingResponse(
_astream_workflow_generator(
request.model_dump()["messages"],
@@ -110,6 +112,154 @@ async def chat_stream(request: ChatRequest):
)
def _process_tool_call_chunks(tool_call_chunks):
"""Process tool call chunks and sanitize arguments."""
chunks = []
for chunk in tool_call_chunks:
chunks.append(
{
"name": chunk.get("name", ""),
"args": sanitize_args(chunk.get("args", "")),
"id": chunk.get("id", ""),
"index": chunk.get("index", 0),
"type": chunk.get("type", ""),
}
)
return chunks
def _get_agent_name(agent, message_metadata):
"""Extract agent name from agent tuple."""
agent_name = "unknown"
if agent and len(agent) > 0:
agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
else:
agent_name = message_metadata.get("langgraph_node", "unknown")
return agent_name
def _create_event_stream_message(
message_chunk, message_metadata, thread_id, agent_name
):
"""Create base event stream message."""
event_stream_message = {
"thread_id": thread_id,
"agent": agent_name,
"id": message_chunk.id,
"role": "assistant",
"checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
"langgraph_node": message_metadata.get("langgraph_node", ""),
"langgraph_path": message_metadata.get("langgraph_path", ""),
"langgraph_step": message_metadata.get("langgraph_step", ""),
"content": message_chunk.content,
}
# Add optional fields
if message_chunk.additional_kwargs.get("reasoning_content"):
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
"reasoning_content"
]
if message_chunk.response_metadata.get("finish_reason"):
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
"finish_reason"
)
return event_stream_message
def _create_interrupt_event(thread_id, event_data):
"""Create interrupt event."""
return _make_event(
"interrupt",
{
"thread_id": thread_id,
"id": event_data["__interrupt__"][0].ns[0],
"role": "assistant",
"content": event_data["__interrupt__"][0].value,
"finish_reason": "interrupt",
"options": [
{"text": "Edit plan", "value": "edit_plan"},
{"text": "Start research", "value": "accepted"},
],
},
)
def _process_initial_messages(message, thread_id):
"""Process initial messages and yield formatted events."""
json_data = json.dumps(
{
"thread_id": thread_id,
"id": "run--" + message.get("id", uuid4().hex),
"role": "user",
"content": message.get("content", ""),
},
ensure_ascii=False,
separators=(",", ":"),
)
chat_stream_message(
thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none"
)
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
"""Process a single message chunk and yield appropriate events."""
agent_name = _get_agent_name(agent, message_metadata)
event_stream_message = _create_event_stream_message(
message_chunk, message_metadata, thread_id, agent_name
)
if isinstance(message_chunk, ToolMessage):
# Tool Message - Return the result of the tool call
event_stream_message["tool_call_id"] = message_chunk.tool_call_id
yield _make_event("tool_call_result", event_stream_message)
elif isinstance(message_chunk, AIMessageChunk):
# AI Message - Raw message tokens
if message_chunk.tool_calls:
# AI Message - Tool Call
event_stream_message["tool_calls"] = message_chunk.tool_calls
event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
message_chunk.tool_call_chunks
)
yield _make_event("tool_calls", event_stream_message)
elif message_chunk.tool_call_chunks:
# AI Message - Tool Call Chunks
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
message_chunk.tool_call_chunks
)
yield _make_event("tool_call_chunks", event_stream_message)
else:
# AI Message - Raw message tokens
yield _make_event("message_chunk", event_stream_message)
async def _stream_graph_events(
graph_instance, workflow_input, workflow_config, thread_id
):
"""Stream events from the graph and process them."""
async for agent, _, event_data in graph_instance.astream(
workflow_input,
config=workflow_config,
stream_mode=["messages", "updates"],
subgraphs=True,
):
if isinstance(event_data, dict):
if "__interrupt__" in event_data:
yield _create_interrupt_event(thread_id, event_data)
continue
message_chunk, message_metadata = cast(
tuple[BaseMessage, dict[str, any]], event_data
)
async for event in _process_message_chunk(
message_chunk, message_metadata, thread_id, agent
):
yield event
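
For illustration, a single tool-call chunk flows through _process_tool_call_chunks like this (values invented; sanitize_args escapes brackets and braces as HTML entities):

```python
chunk = {
    "name": "web_search",  # hypothetical tool name
    "args": '{"query": "[deer]"}',
    "id": "call_1",
    "index": 0,
    "type": "tool_call_chunk",
}
_process_tool_call_chunks([chunk])
# -> [{"name": "web_search",
#      "args": '&#123;"query": "&#91;deer&#93;"&#125;',
#      "id": "call_1", "index": 0, "type": "tool_call_chunk"}]
```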
async def _astream_workflow_generator(
messages: List[dict],
thread_id: str,
@@ -124,7 +274,13 @@ async def _astream_workflow_generator(
report_style: ReportStyle,
enable_deep_thinking: bool,
):
# Process initial messages
for message in messages:
if isinstance(message, dict) and "content" in message:
_process_initial_messages(message, thread_id)
# Prepare workflow input
workflow_input = {
"messages": messages,
"plan_iterations": 0,
"final_report": "",
@@ -134,112 +290,105 @@
"enable_background_investigation": enable_background_investigation,
"research_topic": messages[-1]["content"] if messages else "",
}
if not auto_accepted_plan and interrupt_feedback:
resume_msg = f"[{interrupt_feedback}]"
if messages:
resume_msg += f" {messages[-1]['content']}"
workflow_input = Command(resume=resume_msg)
# Prepare workflow config
workflow_config = {
"thread_id": thread_id,
"resources": resources,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
"mcp_settings": mcp_settings,
"report_style": report_style.value,
"enable_deep_thinking": enable_deep_thinking,
"recursion_limit": get_recursion_limit(),
}
checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
# Handle checkpointer if configured
connection_kwargs = {
"autocommit": True,
"row_factory": "dict_row",
"prepare_threshold": 0,
}
if checkpoint_saver and checkpoint_url != "":
if checkpoint_url.startswith("postgresql://"):
logger.info("start async postgres checkpointer.")
async with AsyncConnectionPool(
checkpoint_url, kwargs=connection_kwargs
) as conn:
checkpointer = AsyncPostgresSaver(conn)
await checkpointer.setup()
graph.checkpointer = checkpointer
graph.store = in_memory_store
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id
):
yield event
if checkpoint_url.startswith("mongodb://"):
logger.info("start async mongodb checkpointer.")
async with AsyncMongoDBSaver.from_conn_string(
checkpoint_url
) as checkpointer:
graph.checkpointer = checkpointer
graph.store = in_memory_store
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id
):
yield event
else:
# Use the graph without a checkpointer
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id
):
yield event
def _make_event(event_type: str, data: dict[str, any]):
if data.get("content") == "":
data.pop("content")
# Ensure JSON serialization with proper encoding
try:
json_data = json.dumps(data, ensure_ascii=False)
finish_reason = data.get("finish_reason", "")
chat_stream_message(
data.get("thread_id", ""),
f"event: {event_type}\ndata: {json_data}\n\n",
finish_reason,
)
return f"event: {event_type}\ndata: {json_data}\n\n"
except (TypeError, ValueError) as e:
logger.error(f"Error serializing event data: {e}")
# Return a safe error event
error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
return f"event: error\ndata: {error_data}\n\n"
@app.post("/api/tts") @app.post("/api/tts")
async def text_to_speech(request: TTSRequest): async def text_to_speech(request: TTSRequest):
"""Convert text to speech using volcengine TTS API.""" """Convert text to speech using volcengine TTS API."""
app_id = os.getenv("VOLCENGINE_TTS_APPID", "") app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
if not app_id: if not app_id:
raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set") raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "") access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
if not access_token: if not access_token:
raise HTTPException( raise HTTPException(
status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set" status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
) )
try: try:
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts") cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming") voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
tts_client = VolcengineTTS( tts_client = VolcengineTTS(
appid=app_id, appid=app_id,
@@ -382,11 +531,7 @@ async def enhance_prompt(request: EnhancePromptRequest):
async def mcp_server_metadata(request: MCPServerMetadataRequest):
"""Get information about an MCP server."""
# Check if MCP server configuration is enabled
if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
raise HTTPException(
status_code=403,
detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",

View File

@@ -3,11 +3,33 @@
import logging
import json
from typing import Any
import json_repair
logger = logging.getLogger(__name__)
def sanitize_args(args: Any) -> str:
"""
Sanitize tool call arguments to prevent special character issues.
Args:
args: Tool call arguments (expected str; non-string values yield an empty string)
Returns:
str: Sanitized arguments string
"""
if not isinstance(args, str):
return ""
else:
return (
args.replace("[", "&#91;")
.replace("]", "&#93;")
.replace("{", "&#123;")
.replace("}", "&#125;")
)
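
For example (illustrative, using the function defined above):

```python
assert sanitize_args('{"a": "[1]"}') == '&#123;"a": "&#91;1&#93;"&#125;'
assert sanitize_args(None) == ""  # non-strings collapse to an empty string
```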
def repair_json_output(content: str) -> str:
"""
Repair and normalize JSON output.

View File

@@ -0,0 +1,660 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
def test_init_without_checkpoint_saver():
"""Manager should not create DB clients when checkpoint_saver is False."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
assert manager.checkpoint_saver is False
# DB connections are not created when saver is disabled
assert manager.mongo_client is None
assert manager.postgres_conn is None
def test_process_stream_partial_buffer_postgres(monkeypatch):
"""Partial chunks should be buffered; Postgres init is stubbed to no-op."""
# Patch Postgres init to no-op
def _no_pg(self):
self.postgres_conn = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_postgresql", _no_pg, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
result = manager.process_stream_message("t1", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t1"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
def test_process_stream_partial_buffer_mongo(monkeypatch):
"""Partial chunks should be buffered; Mongo init is stubbed to no-op."""
# Patch Mongo init to no-op for speed
def _no_mongo(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t2"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_postgresql(thread_id, messages)
assert result is True
# Simulate a message with existing thread
result = manager._persist_to_postgresql(thread_id, ["Another message."])
assert result is True
def test_persist_postgresql_called_with_aggregated_chunks(monkeypatch):
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert (
manager.process_stream_message("thd3", "Hello", finish_reason="partial") is True
)
assert (
manager.process_stream_message("thd3", " World", finish_reason="stop") is True
)
# Verify the messages were aggregated correctly
with manager.postgres_conn.cursor() as cursor:
# Check if conversation already exists
cursor.execute(
"SELECT messages FROM chat_streams WHERE thread_id = %s", ("thd3",)
)
existing_record = cursor.fetchone()
assert existing_record is not None
assert existing_record["messages"] == ["Hello", " World"]
def test_persist_not_attempted_when_saver_disabled():
"""When saver disabled, stop should not persist and should return False."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# stop should try to persist, but saver disabled => returns False
assert manager.process_stream_message("t4", "hello", finish_reason="stop") is False
def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a local MongoDB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Simulate a message with existing thread
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
def test_persist_mongodb_called_with_aggregated_chunks(monkeypatch):
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert (
manager.process_stream_message("thd5", "Hello", finish_reason="partial") is True
)
assert (
manager.process_stream_message("thd5", " World", finish_reason="stop") is True
)
# Verify the messages were aggregated correctly
collection = manager.mongo_db.chat_streams
existing_record = collection.find_one({"thread_id": "thd5"})
assert existing_record is not None
assert existing_record["messages"] == ["Hello", " World"]
def test_invalid_inputs_return_false(monkeypatch):
"""Empty thread_id or message should be rejected and return False."""
def _no_mongo(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.process_stream_message("", "msg", finish_reason="partial") is False
assert manager.process_stream_message("tid", "", finish_reason="partial") is False
def test_unsupported_db_uri_scheme():
"""Manager should log warning for unsupported database URI schemes."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True, db_uri="redis://localhost:6379/0"
)
# Should not have any database connections
assert manager.mongo_client is None
assert manager.postgres_conn is None
assert manager.mongo_db is None
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
# Add partial message
assert (
manager.process_stream_message(
"int_test", "Interrupted", finish_reason="partial"
)
is True
)
# Interrupt should trigger persistence
assert (
manager.process_stream_message(
"int_test", " message", finish_reason="interrupt"
)
is True
)
def test_postgresql_connection_failure(monkeypatch):
"""Test PostgreSQL connection failure handling."""
def failing_connect(dsn, **kwargs):
raise RuntimeError("Connection failed")
monkeypatch.setattr("psycopg.connect", failing_connect)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
# Should have no postgres connection on failure
assert manager.postgres_conn is None
def test_mongodb_ping_failure(monkeypatch):
"""Test MongoDB ping failure during initialization."""
class FakeAdmin:
def command(self, name):
raise RuntimeError("Ping failed")
class FakeClient:
def __init__(self, uri):
self.admin = FakeAdmin()
monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri))
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
# Should not have mongo_db set on ping failure
assert getattr(manager, "mongo_db", None) is None
def test_store_namespace_consistency():
"""Test that store namespace is consistently used across methods."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Process a partial message
assert (
manager.process_stream_message("ns_test", "chunk1", finish_reason="partial")
is True
)
# Verify cursor is stored correctly
cursor = manager.store.get(("messages", "ns_test"), "cursor")
assert cursor is not None
assert cursor.value["index"] == 0
# Add another chunk
assert (
manager.process_stream_message("ns_test", "chunk2", finish_reason="partial")
is True
)
# Verify cursor is incremented
cursor = manager.store.get(("messages", "ns_test"), "cursor")
assert cursor.value["index"] == 1
def test_cursor_initialization_edge_cases():
"""Test cursor handling edge cases."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Manually set a cursor with missing index
namespace = ("messages", "edge_test")
manager.store.put(namespace, "cursor", {}) # Missing 'index' key
# Should handle missing index gracefully
result = manager.process_stream_message(
"edge_test", "test", finish_reason="partial"
)
assert result is True
# Should default to 0 and increment to 1
cursor = manager.store.get(namespace, "cursor")
assert cursor.value["index"] == 1
def test_multiple_threads_isolation():
"""Test that different thread_ids are properly isolated."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Process messages for different threads
assert (
manager.process_stream_message("thread1", "msg1", finish_reason="partial")
is True
)
assert (
manager.process_stream_message("thread2", "msg2", finish_reason="partial")
is True
)
assert (
manager.process_stream_message("thread1", "msg3", finish_reason="partial")
is True
)
# Verify isolation
thread1_items = manager.store.search(("messages", "thread1"), limit=10)
thread2_items = manager.store.search(("messages", "thread2"), limit=10)
thread1_values = [
item.dict()["value"]
for item in thread1_items
if isinstance(item.dict()["value"], str)
]
thread2_values = [
item.dict()["value"]
for item in thread2_items
if isinstance(item.dict()["value"], str)
]
assert "msg1" in thread1_values
assert "msg3" in thread1_values
assert "msg2" in thread2_values
assert "msg1" not in thread2_values
assert "msg2" not in thread1_values
def test_mongodb_insert_and_update_paths(monkeypatch):
"""Exercise MongoDB insert, update, and exception branches."""
# Fake Mongo classes
class FakeUpdateResult:
def __init__(self, modified_count):
self.modified_count = modified_count
class FakeInsertResult:
def __init__(self, inserted_id):
self.inserted_id = inserted_id
class FakeCollection:
def __init__(self, mode="insert_success"):
self.mode = mode
def find_one(self, query):
if self.mode.startswith("insert"):
return None
return {"thread_id": query["thread_id"]}
def update_one(self, q, s):
if self.mode == "update_success":
return FakeUpdateResult(1)
return FakeUpdateResult(0)
def insert_one(self, doc):
if self.mode == "insert_success":
return FakeInsertResult("ok")
if self.mode == "insert_none":
return FakeInsertResult(None)
raise RuntimeError("boom")
class FakeMongoDB:
def __init__(self, mode):
self.chat_streams = FakeCollection(mode)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Insert success
manager.mongo_db = FakeMongoDB("insert_success")
assert manager._persist_to_mongodb("th1", ["a"]) is True
# Insert returns None id => False
manager.mongo_db = FakeMongoDB("insert_none")
assert manager._persist_to_mongodb("th2", ["a"]) is False
# Insert raises => False
manager.mongo_db = FakeMongoDB("insert_raise")
assert manager._persist_to_mongodb("th3", ["a"]) is False
# Update success
manager.mongo_db = FakeMongoDB("update_success")
assert manager._persist_to_mongodb("th4", ["a"]) is True
# Update modifies 0 => False
manager.mongo_db = FakeMongoDB("update_zero")
assert manager._persist_to_mongodb("th5", ["a"]) is False
def test_postgresql_insert_update_and_error_paths():
"""Exercise PostgreSQL update, insert, and error/rollback branches."""
calls = {"executed": []}
class FakeCursor:
def __init__(self, mode):
self.mode = mode
self.rowcount = 0
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql, params=None):
calls["executed"].append(sql.strip().split()[0])
if "SELECT" in sql:
if self.mode == "update":
self._fetch = {"id": "x"}
elif self.mode == "error":
raise RuntimeError("sql error")
else:
self._fetch = None
else:
# UPDATE or INSERT
self.rowcount = 1
def fetchone(self):
return getattr(self, "_fetch", None)
class FakeConn:
def __init__(self, mode):
self.mode = mode
self.commit_called = False
self.rollback_called = False
def cursor(self):
return FakeCursor(self.mode)
def commit(self):
self.commit_called = True
def rollback(self):
self.rollback_called = True
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
# Update path
manager.postgres_conn = FakeConn("update")
assert manager._persist_to_postgresql("t", ["m"]) is True
assert manager.postgres_conn.commit_called is True
# Insert path
manager.postgres_conn = FakeConn("insert")
assert manager._persist_to_postgresql("t", ["m"]) is True
assert manager.postgres_conn.commit_called is True
# Error path with rollback
manager.postgres_conn = FakeConn("error")
assert manager._persist_to_postgresql("t", ["m"]) is False
assert manager.postgres_conn.rollback_called is True
def test_create_chat_streams_table_success_and_error():
"""Ensure table creation commits on success and rolls back on failure."""
class FakeCursor:
def __init__(self, should_fail=False):
self.should_fail = should_fail
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql):
if self.should_fail:
raise RuntimeError("ddl fail")
class FakeConn:
def __init__(self, should_fail=False):
self.should_fail = should_fail
self.commits = 0
self.rollbacks = 0
def cursor(self):
return FakeCursor(self.should_fail)
def commit(self):
self.commits += 1
def rollback(self):
self.rollbacks += 1
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
# Success
manager.postgres_conn = FakeConn(False)
manager._create_chat_streams_table()
assert manager.postgres_conn.commits == 1
# Failure triggers rollback
manager.postgres_conn = FakeConn(True)
manager._create_chat_streams_table()
assert manager.postgres_conn.rollbacks == 1
def test_close_closes_resources_and_handles_errors():
"""Close should gracefully handle both success and exceptions."""
flags = {"mongo": 0, "pg": 0}
class M:
def close(self):
flags["mongo"] += 1
class P:
def __init__(self, raise_on_close=False):
self.raise_on_close = raise_on_close
def close(self):
if self.raise_on_close:
raise RuntimeError("close fail")
flags["pg"] += 1
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
manager.mongo_client = M()
manager.postgres_conn = P()
manager.close()
assert flags == {"mongo": 1, "pg": 1}
# Trigger error branches (no raise escapes)
manager.mongo_client = None # skip mongo
manager.postgres_conn = P(True)
manager.close() # should handle exception gracefully
def test_context_manager_calls_close(monkeypatch):
"""The context manager protocol should call close() on exit."""
called = {"close": 0}
def _noop(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _noop, raising=True
)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
def fake_close():
called["close"] += 1
manager.close = fake_close
with manager:
pass
assert called["close"] == 1
def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with a valid client and fail gracefully otherwise."""
class FakeAdmin:
def command(self, name):
assert name == "ping"
class DummyDB:
pass
class FakeClient:
def __init__(self, uri):
self.uri = uri
self.admin = FakeAdmin()
self.checkpointing_db = DummyDB()
def close(self):
pass
# Success path
monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri))
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
# Failure path
class Boom:
def __init__(self, uri):
raise RuntimeError("fail connect")
monkeypatch.setattr(checkpoint, "MongoClient", Boom)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):
"""PostgreSQL init should connect and create the required table."""
flags = {"connected": 0, "created": 0}
class FakeConn:
def __init__(self):
pass
def close(self):
pass
def fake_connect(self):
flags["connected"] += 1
flags["created"] += 1
return FakeConn()
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_postgresql", fake_connect, raising=True
)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
assert manager.postgres_conn is None
assert flags == {"connected": 1, "created": 1}
def test_chat_stream_message_wrapper(monkeypatch):
"""Wrapper should delegate when enabled and return False when disabled."""
# When saver enabled, should call default manager
monkeypatch.setattr(
checkpoint, "get_bool_env", lambda k, d=False: True, raising=True
)
called = {"args": None}
def fake_process(tid, msg, fr):
called["args"] = (tid, msg, fr)
return True
monkeypatch.setattr(
checkpoint._default_manager,
"process_stream_message",
fake_process,
raising=True,
)
assert checkpoint.chat_stream_message("tid", "msg", "stop") is True
assert called["args"] == ("tid", "msg", "stop")
# When saver disabled, returns False and does not call manager
monkeypatch.setattr(
checkpoint, "get_bool_env", lambda k, d=False: False, raising=True
)
called["args"] = None
assert checkpoint.chat_stream_message("tid", "msg", "stop") is False
assert called["args"] is None

uv.lock (generated, 146 lines)
View File

@@ -1,5 +1,5 @@
version = 1
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -383,6 +383,8 @@ dependencies = [
{ name = "langchain-mcp-adapters" }, { name = "langchain-mcp-adapters" },
{ name = "langchain-openai" }, { name = "langchain-openai" },
{ name = "langgraph" }, { name = "langgraph" },
{ name = "langgraph-checkpoint-mongodb" },
{ name = "langgraph-checkpoint-postgres" },
{ name = "litellm" }, { name = "litellm" },
{ name = "markdownify" }, { name = "markdownify" },
{ name = "mcp" }, { name = "mcp" },
@@ -425,6 +427,8 @@ requires-dist = [
{ name = "langchain-mcp-adapters", specifier = ">=0.0.9" }, { name = "langchain-mcp-adapters", specifier = ">=0.0.9" },
{ name = "langchain-openai", specifier = ">=0.3.8" }, { name = "langchain-openai", specifier = ">=0.3.8" },
{ name = "langgraph", specifier = ">=0.3.5" }, { name = "langgraph", specifier = ">=0.3.5" },
{ name = "langgraph-checkpoint-mongodb", specifier = ">=0.1.4" },
{ name = "langgraph-checkpoint-postgres", specifier = "==2.0.21" },
{ name = "langgraph-cli", extras = ["inmem"], marker = "extra == 'dev'", specifier = ">=0.2.10" }, { name = "langgraph-cli", extras = ["inmem"], marker = "extra == 'dev'", specifier = ">=0.2.10" },
{ name = "litellm", specifier = ">=1.63.11" }, { name = "litellm", specifier = ">=1.63.11" },
{ name = "markdownify", specifier = ">=1.1.0" }, { name = "markdownify", specifier = ">=1.1.0" },
@@ -454,6 +458,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
] ]
[[package]]
name = "dnspython"
version = "2.7.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/4a/263763cb2ba3816dd94b08ad3a33d5fdae34ecb856678773cc40a3605829/dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1", size = 345197, upload-time = "2024-10-05T20:14:59.362Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632, upload-time = "2024-10-05T20:14:57.687Z" },
]
[[package]] [[package]]
name = "duckduckgo-search" name = "duckduckgo-search"
version = "8.0.0" version = "8.0.0"
@@ -946,6 +959,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/eb/9e98822d3db22beff44449a8f61fca208d4f59d592a7ce67ce4c400b8f8f/langchain_mcp_adapters-0.1.9-py3-none-any.whl", hash = "sha256:fd131009c60c9e5a864f96576bbe757fc1809abd604891cb2e5d6e8aebd6975c", size = 15300, upload-time = "2025-07-09T15:56:13.316Z" }, { url = "https://files.pythonhosted.org/packages/a4/eb/9e98822d3db22beff44449a8f61fca208d4f59d592a7ce67ce4c400b8f8f/langchain_mcp_adapters-0.1.9-py3-none-any.whl", hash = "sha256:fd131009c60c9e5a864f96576bbe757fc1809abd604891cb2e5d6e8aebd6975c", size = 15300, upload-time = "2025-07-09T15:56:13.316Z" },
] ]
[[package]]
name = "langchain-mongodb"
version = "0.6.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain" },
{ name = "langchain-core" },
{ name = "langchain-text-splitters" },
{ name = "lark" },
{ name = "numpy" },
{ name = "pymongo" },
]
sdist = { url = "https://files.pythonhosted.org/packages/02/82/ea0cfd092843f09c71d51d9b1fc3a5051fa477fb5d94b95e3b7bf73d6fd2/langchain_mongodb-0.6.2.tar.gz", hash = "sha256:fae221017b5db8a239837b2d163cf6493de6072e217239a95e3aa1fc3303615e", size = 237734, upload-time = "2025-05-12T14:37:51.964Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/15/fa/d47c1a3dd7ff07709023ff75284544a0d315c51e6de69ac01ff90a642dfe/langchain_mongodb-0.6.2-py3-none-any.whl", hash = "sha256:17f740e16582b8b6b241e625fb5c0ae273af1ccb13f2877ade1f1f14049e51a0", size = 59133, upload-time = "2025-05-12T14:37:50.603Z" },
]
[[package]] [[package]]
name = "langchain-openai" name = "langchain-openai"
version = "0.3.22" version = "0.3.22"
@@ -1032,6 +1062,36 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/52/bceb5b5348c7a60ef0625ab0a0a0a9ff5d78f0e12aed8cc55c49d5e8a8c9/langgraph_checkpoint-2.0.25-py3-none-any.whl", hash = "sha256:23416a0f5bc9dd712ac10918fc13e8c9c4530c419d2985a441df71a38fc81602", size = 42312, upload-time = "2025-04-26T21:00:42.242Z" }, { url = "https://files.pythonhosted.org/packages/12/52/bceb5b5348c7a60ef0625ab0a0a0a9ff5d78f0e12aed8cc55c49d5e8a8c9/langgraph_checkpoint-2.0.25-py3-none-any.whl", hash = "sha256:23416a0f5bc9dd712ac10918fc13e8c9c4530c419d2985a441df71a38fc81602", size = 42312, upload-time = "2025-04-26T21:00:42.242Z" },
] ]
[[package]]
name = "langgraph-checkpoint-mongodb"
version = "0.1.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-mongodb" },
{ name = "langgraph-checkpoint" },
{ name = "motor" },
{ name = "pymongo" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f5/c8/062bef92e96a36ea14658f16c43b87892a98c65137b4e088044790960e91/langgraph_checkpoint_mongodb-0.1.4.tar.gz", hash = "sha256:cae9a63a80d8259388b23e941438b7ae56e20570c1f39f640ccb9f28f77a67fe", size = 144572, upload-time = "2025-06-13T20:20:06.563Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/9b/b78c4366cf21bb908f90e348999b9353f85d7c2e40a4e169bb9ebe6ff37a/langgraph_checkpoint_mongodb-0.1.4-py3-none-any.whl", hash = "sha256:5edfa15e0fc03c27b7dda840a001bd210f16abc75b25b64405bf901ca655fdb5", size = 11226, upload-time = "2025-06-13T20:20:05.855Z" },
]
[[package]]
name = "langgraph-checkpoint-postgres"
version = "2.0.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langgraph-checkpoint" },
{ name = "orjson" },
{ name = "psycopg" },
{ name = "psycopg-pool" },
]
sdist = { url = "https://files.pythonhosted.org/packages/90/51/18138b116807180c97093890dd7955f0e68eabcba97d323a61275bab45b6/langgraph_checkpoint_postgres-2.0.21.tar.gz", hash = "sha256:921915fd3de534b4c84469f93d03046c1ef1f224e44629212b172ec3e9b72ded", size = 30371, upload-time = "2025-04-18T16:31:50.164Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/31/d5f4a7dd63dddfdb85209a3cbc1778b14bc0dddadb431e34938956f45e8c/langgraph_checkpoint_postgres-2.0.21-py3-none-any.whl", hash = "sha256:f0a50f2c1496778e00ea888415521bb2b7789a12052aa5ae54d82cf517b271e8", size = 39440, upload-time = "2025-04-18T16:31:48.838Z" },
]
[[package]] [[package]]
name = "langgraph-cli" name = "langgraph-cli"
version = "0.2.10" version = "0.2.10"
@@ -1112,6 +1172,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/f4/c206c0888f8a506404cb4f16ad89593bdc2f70cf00de26a1a0a7a76ad7a3/langsmith-0.3.45-py3-none-any.whl", hash = "sha256:5b55f0518601fa65f3bb6b1a3100379a96aa7b3ed5e9380581615ba9c65ed8ed", size = 363002, upload-time = "2025-06-05T05:10:27.228Z" }, { url = "https://files.pythonhosted.org/packages/6a/f4/c206c0888f8a506404cb4f16ad89593bdc2f70cf00de26a1a0a7a76ad7a3/langsmith-0.3.45-py3-none-any.whl", hash = "sha256:5b55f0518601fa65f3bb6b1a3100379a96aa7b3ed5e9380581615ba9c65ed8ed", size = 363002, upload-time = "2025-06-05T05:10:27.228Z" },
] ]
[[package]]
name = "lark"
version = "1.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132, upload-time = "2024-08-13T19:49:00.652Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036, upload-time = "2024-08-13T19:48:58.603Z" },
]
[[package]] [[package]]
name = "litellm" name = "litellm"
version = "1.63.11" version = "1.63.11"
@@ -1261,6 +1330,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2f/cf/3fd38cfe43962452e4bfadc6966b2ea0afaf8e0286cb3991c247c8c33ebd/mcp-1.12.2-py3-none-any.whl", hash = "sha256:b86d584bb60193a42bd78aef01882c5c42d614e416cbf0480149839377ab5a5f", size = 158473, upload-time = "2025-07-24T18:29:03.419Z" }, { url = "https://files.pythonhosted.org/packages/2f/cf/3fd38cfe43962452e4bfadc6966b2ea0afaf8e0286cb3991c247c8c33ebd/mcp-1.12.2-py3-none-any.whl", hash = "sha256:b86d584bb60193a42bd78aef01882c5c42d614e416cbf0480149839377ab5a5f", size = 158473, upload-time = "2025-07-24T18:29:03.419Z" },
] ]
[[package]]
name = "motor"
version = "3.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pymongo" },
]
sdist = { url = "https://files.pythonhosted.org/packages/93/ae/96b88362d6a84cb372f7977750ac2a8aed7b2053eed260615df08d5c84f4/motor-3.7.1.tar.gz", hash = "sha256:27b4d46625c87928f331a6ca9d7c51c2f518ba0e270939d395bc1ddc89d64526", size = 280997, upload-time = "2025-05-14T18:56:33.653Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/01/9a/35e053d4f442addf751ed20e0e922476508ee580786546d699b0567c4c67/motor-3.7.1-py3-none-any.whl", hash = "sha256:8a63b9049e38eeeb56b4fdd57c3312a6d1f25d01db717fe7d82222393c410298", size = 74996, upload-time = "2025-05-14T18:56:31.665Z" },
]
[[package]] [[package]]
name = "multidict" name = "multidict"
version = "6.1.0" version = "6.1.0"
@@ -1603,6 +1684,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101, upload-time = "2025-02-20T19:03:27.202Z" }, { url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101, upload-time = "2025-02-20T19:03:27.202Z" },
] ]
[[package]]
name = "psycopg"
version = "3.2.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "tzdata", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/27/4a/93a6ab570a8d1a4ad171a1f4256e205ce48d828781312c0bbaff36380ecb/psycopg-3.2.9.tar.gz", hash = "sha256:2fbb46fcd17bc81f993f28c47f1ebea38d66ae97cc2dbc3cad73b37cefbff700", size = 158122, upload-time = "2025-05-13T16:11:15.533Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/44/b0/a73c195a56eb6b92e937a5ca58521a5c3346fb233345adc80fd3e2f542e2/psycopg-3.2.9-py3-none-any.whl", hash = "sha256:01a8dadccdaac2123c916208c96e06631641c0566b22005493f09663c7a8d3b6", size = 202705, upload-time = "2025-05-13T16:06:26.584Z" },
]
[[package]]
name = "psycopg-pool"
version = "3.2.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cf/13/1e7850bb2c69a63267c3dbf37387d3f71a00fd0e2fa55c5db14d64ba1af4/psycopg_pool-3.2.6.tar.gz", hash = "sha256:0f92a7817719517212fbfe2fd58b8c35c1850cdd2a80d36b581ba2085d9148e5", size = 29770, upload-time = "2025-02-26T12:03:47.129Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/47/fd/4feb52a55c1a4bd748f2acaed1903ab54a723c47f6d0242780f4d97104d4/psycopg_pool-3.2.6-py3-none-any.whl", hash = "sha256:5887318a9f6af906d041a0b1dc1c60f8f0dda8340c2572b74e10907b51ed5da7", size = 38252, upload-time = "2025-02-26T12:03:45.073Z" },
]
[[package]] [[package]]
name = "pycparser" name = "pycparser"
version = "2.22" version = "2.22"
@@ -1687,6 +1793,44 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
] ]
[[package]]
name = "pymongo"
version = "4.12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dnspython" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/27/3634b2e8d88ad210ee6edac69259c698aefed4a79f0f7356cd625d5c423c/pymongo-4.12.1.tar.gz", hash = "sha256:8921bac7f98cccb593d76c4d8eaa1447e7d537ba9a2a202973e92372a05bd1eb", size = 2165515, upload-time = "2025-04-29T18:46:23.62Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/59/dd/b684de28bfaf7e296538601c514d4613f98b77cfa1de323c7b160f4e04d0/pymongo-4.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a7b771aa2f0854ddf7861e8ce2365f29df9159393543d047e43d8475bc4b8813", size = 910797, upload-time = "2025-04-29T18:44:57.783Z" },
{ url = "https://files.pythonhosted.org/packages/e8/80/4fadd5400a4fbe57e7ea0349f132461d5dfc46c124937600f5044290d817/pymongo-4.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34fd8681b6fa6e1025dd1000004f6b81cbf1961f145b8c58bd15e3957976068d", size = 910489, upload-time = "2025-04-29T18:45:01.089Z" },
{ url = "https://files.pythonhosted.org/packages/4e/83/303be22944312cc28e3a357556d21971c388189bf90aebc79e752afa2452/pymongo-4.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981e19b8f1040247dee5f7879e45f640f7e21a4d87eabb19283ce5a2927dd2e7", size = 1689142, upload-time = "2025-04-29T18:45:03.008Z" },
{ url = "https://files.pythonhosted.org/packages/a4/67/f4e8506caf001ab9464df2562e3e022b7324e7c10a979ce1b55b006f2445/pymongo-4.12.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9a487dc1fe92736987a156325d3d9c66cbde6eac658b2875f5f222b6d82edca", size = 1753373, upload-time = "2025-04-29T18:45:04.874Z" },
{ url = "https://files.pythonhosted.org/packages/2e/7c/22d65c2a4e3e941b345b8cc164b3b53f2c1d0db581d4991817b6375ef507/pymongo-4.12.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1525051c13984365c4a9b88ee2d63009fae277921bc89a0d323b52c51f91cbac", size = 1722399, upload-time = "2025-04-29T18:45:06.726Z" },
{ url = "https://files.pythonhosted.org/packages/07/0d/32fd1ebafd0090510fb4820d175fe35d646e5b28c71ad9c36cb3ce554567/pymongo-4.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ad689e0e4f364809084f9e5888b2dcd6f0431b682a1c68f3fdf241e20e14475", size = 1692374, upload-time = "2025-04-29T18:45:08.552Z" },
{ url = "https://files.pythonhosted.org/packages/e3/9c/d7a30ce6b983c3955c225e3038dafb4f299281775323f58b378f2a7e6e59/pymongo-4.12.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f9b18abca210c2917041ab2a380c12f6ddd2810844f1d64afb39caf8a15425e", size = 1651490, upload-time = "2025-04-29T18:45:10.658Z" },
{ url = "https://files.pythonhosted.org/packages/29/b3/7902d73df1d088ec0c60c19ef4bd7894c6e6e4dfbfd7ab4ae4fbedc9427c/pymongo-4.12.1-cp312-cp312-win32.whl", hash = "sha256:d9d90fec041c6d695a639c26ca83577aa74383f5e3744fd7931537b208d5a1b5", size = 879521, upload-time = "2025-04-29T18:45:12.993Z" },
{ url = "https://files.pythonhosted.org/packages/8c/68/a17ff6472e6be12bae75f5d11db4e3dccc55e02dcd4e66cd87871790a20e/pymongo-4.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:d004b13e4f03d73a3ad38505ba84b61a2c8ba0a304f02fe1b27bfc986c244192", size = 897765, upload-time = "2025-04-29T18:45:15.296Z" },
{ url = "https://files.pythonhosted.org/packages/0c/4d/e6654f3ec6819980cbad77795ccf2275cd65d6df41375a22cdbbccef8416/pymongo-4.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:90de2b060d69c22658ada162a5380a0f88cb8c0149023241b9e379732bd36152", size = 965051, upload-time = "2025-04-29T18:45:17.516Z" },
{ url = "https://files.pythonhosted.org/packages/54/95/627a047c32789544a938abfd9311c914e622cb036ad16866e7e1b9b80239/pymongo-4.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edf4e05331ac875d3b27b4654b74d81e44607af4aa7d6bcd4a31801ca164e6fd", size = 964732, upload-time = "2025-04-29T18:45:19.478Z" },
{ url = "https://files.pythonhosted.org/packages/8f/6d/7a604e3ab5399f8fe1ca88abdbf7e54ceb6cf03e64f68b2ed192d9a5eaf5/pymongo-4.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa7a817c9afb7b8775d98c469ddb3fe9c17daf53225394c1a74893cf45d3ade9", size = 1953037, upload-time = "2025-04-29T18:45:22.115Z" },
{ url = "https://files.pythonhosted.org/packages/d5/d5/269388e7b0d02d35f55440baf1e0120320b6db1b555eaed7117d04b35402/pymongo-4.12.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d142ca531694e9324b3c9ba86c0e905c5f857599c4018a386c4dc02ca490fa", size = 2030467, upload-time = "2025-04-29T18:45:24.069Z" },
{ url = "https://files.pythonhosted.org/packages/4b/d0/04a6b48d6ca3fc2ff156185a3580799a748cf713239d6181e91234a663d3/pymongo-4.12.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5d4c0461f5cd84d9fe87d5a84b1bc16371c4dd64d56dcfe5e69b15c0545a5ac", size = 1994139, upload-time = "2025-04-29T18:45:26.215Z" },
{ url = "https://files.pythonhosted.org/packages/ad/65/0567052d52c0ac8aaa4baa700b39cdd1cf2481d2e59bd9817a3daf169ca0/pymongo-4.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43afd2f39182731ac9fb81bbc9439d539e4bd2eda72cdee829d2fa906a1c4d37", size = 1954947, upload-time = "2025-04-29T18:45:28.423Z" },
{ url = "https://files.pythonhosted.org/packages/c5/5b/db25747b288218dbdd97e9aeff6a3bfa3f872efb4ed06fa8bec67b2a121e/pymongo-4.12.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:827ac668c003da7b175b8e5f521850e2c182b4638a3dec96d97f0866d5508a1e", size = 1904374, upload-time = "2025-04-29T18:45:30.943Z" },
{ url = "https://files.pythonhosted.org/packages/fc/1e/6d0eb040c02ae655fafd63bd737e96d7e832eecfd0bd37074d0066f94a78/pymongo-4.12.1-cp313-cp313-win32.whl", hash = "sha256:7c2269b37f034124a245eaeb34ce031cee64610437bd597d4a883304babda3cd", size = 925869, upload-time = "2025-04-29T18:45:32.998Z" },
{ url = "https://files.pythonhosted.org/packages/59/b9/459da646d9750529f04e7e686f0cd8dd40174138826574885da334c01b16/pymongo-4.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:3b28ecd1305b89089be14f137ffbdf98a3b9f5c8dbbb2be4dec084f2813fbd5f", size = 948411, upload-time = "2025-04-29T18:45:35.445Z" },
{ url = "https://files.pythonhosted.org/packages/c9/c3/75be116159f210811656ec615b2248f63f1bc9dd1ce641e18db2552160f0/pymongo-4.12.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:f27b22a8215caff68bdf46b5b61ccd843a68334f2aa4658e8d5ecb5d3fbebb3b", size = 1021562, upload-time = "2025-04-29T18:45:37.433Z" },
{ url = "https://files.pythonhosted.org/packages/cd/d1/2e8e368cad1c126a68365a6f53feaade58f9a16bd5f7a69f218af119b0e9/pymongo-4.12.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e9d23a3c290cf7409515466a7f11069b70e38ea2b786bbd7437bdc766c9e176", size = 1021553, upload-time = "2025-04-29T18:45:39.344Z" },
{ url = "https://files.pythonhosted.org/packages/17/6e/a6460bc1e3d3f5f46cc151417427b2687a6f87972fd68a33961a37c114df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efeb430f7ca8649a6544a50caefead343d1fd096d04b6b6a002c6ce81148a85c", size = 2281736, upload-time = "2025-04-29T18:45:41.462Z" },
{ url = "https://files.pythonhosted.org/packages/1a/e2/9e1d6f1a492bb02116074baa832716805a0552d757c176e7c5f40867ca80/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a34e4a08bbcff56fdee86846afbc9ce751de95706ca189463e01bf5de3dd9927", size = 2368964, upload-time = "2025-04-29T18:45:43.579Z" },
{ url = "https://files.pythonhosted.org/packages/fa/df/88143016eca77e79e38cf072476c70dd360962934430447dabc9c6bef6df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b063344e0282537f05dbb11147591cbf58fc09211e24fc374749e343f880910a", size = 2327834, upload-time = "2025-04-29T18:45:45.847Z" },
{ url = "https://files.pythonhosted.org/packages/3c/0d/df2998959b52cd5682b11e6eee1b0e0c104c07abd99c9cde5a871bb299fd/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3f7941e01b3e5d4bfb3b4711425e809df8c471b92d1da8d6fab92c7e334a4cb", size = 2279126, upload-time = "2025-04-29T18:45:48.445Z" },
{ url = "https://files.pythonhosted.org/packages/fb/3e/102636f5aaf97ccfa2a156c253a89f234856a0cd252fa602d4bf077ba3c0/pymongo-4.12.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b41235014031739f32be37ff13992f51091dae9a5189d3bcc22a5bf81fd90dae", size = 2218136, upload-time = "2025-04-29T18:45:50.57Z" },
{ url = "https://files.pythonhosted.org/packages/44/c9/1b534c9d8d91d9d98310f2d955c5331fb522bd2a0105bd1fc31771d53758/pymongo-4.12.1-cp313-cp313t-win32.whl", hash = "sha256:9a1f07fe83a8a34651257179bd38d0f87bd9d90577fcca23364145c5e8ba1bc0", size = 974747, upload-time = "2025-04-29T18:45:52.66Z" },
{ url = "https://files.pythonhosted.org/packages/08/e2/7d3a30ac905c99ea93729e03d2bb3d16fec26a789e98407d61cb368ab4bb/pymongo-4.12.1-cp313-cp313t-win_amd64.whl", hash = "sha256:46d86cf91ee9609d0713242a1d99fa9e9c60b4315e1a067b9a9e769bedae629d", size = 1003332, upload-time = "2025-04-29T18:45:54.631Z" },
]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "8.3.5" version = "8.3.5"

View File

@@ -49,7 +49,11 @@ function mergeTextMessage(message: Message, event: MessageChunkEvent) {
     message.reasoningContentChunks.push(event.data.reasoning_content);
   }
 }
+function convertToolChunkArgs(args: string) {
+  // Convert escaped characters in args
+  if (!args) return "";
+  return args.replace(/&#91;/g, "[").replace(/&#93;/g, "]").replace(/&#123;/g, "{").replace(/&#125;/g, "}");
+}
 function mergeToolCallMessage(
   message: Message,
   event: ToolCallsEvent | ToolCallChunksEvent,
@@ -70,14 +74,14 @@ function mergeToolCallMessage(
         (toolCall) => toolCall.id === chunk.id,
       );
       if (toolCall) {
-        toolCall.argsChunks = [chunk.args];
+        toolCall.argsChunks = [convertToolChunkArgs(chunk.args)];
       }
     } else {
       const streamingToolCall = message.toolCalls.find(
        (toolCall) => toolCall.argsChunks?.length,
      );
      if (streamingToolCall) {
-        streamingToolCall.argsChunks!.push(chunk.args);
+        streamingToolCall.argsChunks!.push(convertToolChunkArgs(chunk.args));
      }
    }
  }
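The replacements undone by convertToolChunkArgs suggest the backend escapes the JSON structural characters [, ], {, and } as HTML entities before streaming tool-call argument chunks. Below is a minimal Python sketch of that round trip; the sanitize_tool_call_args name and placement are assumptions, and only the entity mapping is taken from the diff above.

# Hedged sketch of the escape/unescape round trip; sanitize_tool_call_args
# is an assumed stand-in for the backend sanitizer, not repository code.
ESCAPES = {"[": "&#91;", "]": "&#93;", "{": "&#123;", "}": "&#125;"}

def sanitize_tool_call_args(args: str) -> str:
    # Escape structural characters before streaming a chunk to the client.
    for char, entity in ESCAPES.items():
        args = args.replace(char, entity)
    return args

def convert_tool_chunk_args(args: str) -> str:
    # Python mirror of the TypeScript convertToolChunkArgs above.
    if not args:
        return ""
    for char, entity in ESCAPES.items():
        args = args.replace(entity, char)
    return args

chunk = '{"query": ["deer", "flow"]}'
assert convert_tool_chunk_args(sanitize_tool_call_args(chunk)) == chunk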