Files
deer-flow/src/graph/checkpoint.py
YikB 7cd2265272 append messages to chat_streams table (#816)
* feat: Implement DeerFlow API server with chat streaming, Langgraph orchestration, and various content generation capabilities.

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* - Use MongoDB `$push` with `$each` to append new messages to existing threads
- Use PostgreSQL jsonb concatenation operator to merge messages instead of overwriting
- Update comments to reflect append behavior in both database implementations

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-22 09:09:15 +08:00

389 lines
15 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import psycopg
from langgraph.store.memory import InMemoryStore
from psycopg.rows import dict_row
from pymongo import MongoClient

from src.config.loader import get_bool_env, get_str_env
class ChatStreamManager:
    """
    Manages chat stream messages with persistent storage and in-memory caching.

    While a response is streaming, each message chunk for a thread is buffered
    in an in-memory store. When a chunk arrives with a terminal finish reason
    ("stop" or "interrupt"), the buffered chunks are appended to the thread's
    record in MongoDB or PostgreSQL and the buffer is cleared.

    Attributes:
        store (InMemoryStore): In-memory storage for temporary message chunks
        checkpoint_saver (bool): Whether persistence is enabled
        db_uri (Optional[str]): Database connection URI in use
        mongo_client (MongoClient): MongoDB client connection, or None
        mongo_db (Database): MongoDB database instance, or None
        postgres_conn (psycopg.Connection): PostgreSQL connection, or None
        logger (logging.Logger): Logger instance for this class
    """

    # Finish reasons that mark the end of a streamed response.
    _TERMINAL_FINISH_REASONS = ("stop", "interrupt")
    # Key prefix of buffered chunks; the per-thread counter lives under "cursor".
    _CHUNK_PREFIX = "chunk_"

    def __init__(
        self, checkpoint_saver: bool = False, db_uri: Optional[str] = None
    ) -> None:
        """
        Initialize the ChatStreamManager with database connections.

        Args:
            checkpoint_saver: Enable persistence. When False no database
                connection is opened and completed conversations are not saved.
            db_uri: Database connection URI. Supports MongoDB (mongodb://,
                mongodb+srv://) and PostgreSQL (postgresql://, postgres://).
                If None, uses the LANGGRAPH_CHECKPOINT_DB_URL env var, falling
                back to mongodb://localhost:27017.
        """
        self.logger = logging.getLogger(__name__)
        self.store = InMemoryStore()
        self.checkpoint_saver = checkpoint_saver
        # Apply the documented fallback instead of failing on None.startswith().
        if db_uri is None:
            db_uri = get_str_env(
                "LANGGRAPH_CHECKPOINT_DB_URL", "mongodb://localhost:27017"
            )
        self.db_uri = db_uri
        # Connections stay None unless the matching backend connects successfully.
        self.mongo_client = None
        self.mongo_db = None
        self.postgres_conn = None

        if not self.checkpoint_saver:
            self.logger.warning("Checkpoint saver is disabled")
            return

        backend = self._db_backend(self.db_uri)
        if backend == "mongodb":
            self._init_mongodb()
        elif backend == "postgresql":
            self._init_postgresql()
        else:
            # Log only the scheme: the full URI may embed credentials.
            scheme = (self.db_uri or "").split("://", 1)[0]
            self.logger.warning(
                f"Unsupported database URI scheme: {scheme!r}. "
                "Supported schemes: mongodb://, mongodb+srv://, "
                "postgresql://, postgres://"
            )

    @staticmethod
    def _db_backend(db_uri: Optional[str]) -> Optional[str]:
        """Return "mongodb" or "postgresql" for a supported URI scheme, else None."""
        if not db_uri:
            return None
        if db_uri.startswith(("mongodb://", "mongodb+srv://")):
            return "mongodb"
        if db_uri.startswith(("postgresql://", "postgres://")):
            return "postgresql"
        return None

    def _init_mongodb(self) -> None:
        """Connect to MongoDB; on failure log the error and leave the handles as None."""
        client = None
        try:
            client = MongoClient(self.db_uri)
            # Round-trip to the server so an unreachable database fails here.
            client.admin.command("ping")
        except Exception as e:
            self.logger.error(f"Failed to connect to MongoDB: {e}")
            # Do not keep a half-initialised client that later persists would hit.
            if client is not None:
                client.close()
            return
        self.mongo_client = client
        self.mongo_db = client.checkpointing_db
        self.logger.info("Successfully connected to MongoDB")

    def _init_postgresql(self) -> None:
        """Initialize PostgreSQL connection and create table if needed."""
        try:
            self.postgres_conn = psycopg.connect(self.db_uri, row_factory=dict_row)
            self.logger.info("Successfully connected to PostgreSQL")
            self._create_chat_streams_table()
        except Exception as e:
            self.logger.error(f"Failed to connect to PostgreSQL: {e}")

    def _create_chat_streams_table(self) -> None:
        """Create the chat_streams table (unique per thread_id) if it doesn't exist."""
        try:
            with self.postgres_conn.cursor() as cursor:
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS chat_streams (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    thread_id VARCHAR(255) NOT NULL UNIQUE,
                    messages JSONB NOT NULL,
                    ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
                );
                CREATE INDEX IF NOT EXISTS idx_chat_streams_thread_id ON chat_streams(thread_id);
                CREATE INDEX IF NOT EXISTS idx_chat_streams_ts ON chat_streams(ts);
                """
                cursor.execute(create_table_sql)
            self.postgres_conn.commit()
            self.logger.info("Chat streams table created/verified successfully")
        except Exception as e:
            self.logger.error(f"Failed to create chat_streams table: {e}")
            if self.postgres_conn:
                self.postgres_conn.rollback()

    def process_stream_message(
        self, thread_id: str, message: str, finish_reason: str
    ) -> bool:
        """
        Process and store a chat stream message chunk.

        Each chunk is buffered in memory under an increasing index. When
        finish_reason is "stop" or "interrupt", all buffered chunks for the
        thread are appended to MongoDB or PostgreSQL.

        Args:
            thread_id: Unique identifier for the conversation thread
            message: The message content or chunk to store
            finish_reason: Reason for message completion ("stop", "interrupt", or partial)

        Returns:
            bool: True if the chunk was buffered (and, for a terminal chunk, the
            conversation was persisted); False on invalid input, errors, or a
            failed/disabled persistence.
        """
        if not thread_id or not isinstance(thread_id, str):
            self.logger.warning("Invalid thread_id provided")
            return False
        if not message:
            self.logger.warning("Empty message provided")
            return False
        try:
            store_namespace: Tuple[str, str] = ("messages", thread_id)
            # The cursor records the index of the most recently stored chunk.
            cursor = self.store.get(store_namespace, "cursor")
            if cursor is None:
                current_index = 0
            else:
                current_index = int(cursor.value.get("index", 0)) + 1
            self.store.put(store_namespace, "cursor", {"index": current_index})
            self.store.put(
                store_namespace, f"{self._CHUNK_PREFIX}{current_index}", message
            )
            if finish_reason in self._TERMINAL_FINISH_REASONS:
                return self._persist_complete_conversation(
                    thread_id, store_namespace, current_index
                )
            return True
        except Exception as e:
            self.logger.error(
                f"Error processing stream message for thread {thread_id}: {e}"
            )
            return False

    def _persist_complete_conversation(
        self, thread_id: str, store_namespace: Tuple[str, str], final_index: int
    ) -> bool:
        """
        Persist completed conversation to database (MongoDB or PostgreSQL).

        Args:
            thread_id: Unique identifier for the conversation thread
            store_namespace: Namespace tuple for accessing stored messages
            final_index: The final chunk index for this conversation

        Returns:
            bool: True if persistence was successful, False otherwise. On a
            failed database write the buffer is kept, so the next terminal
            chunk for this thread re-attempts with these chunks included.
        """
        try:
            # The namespace holds the cursor plus chunks 0..final_index.
            memories = self.store.search(store_namespace, limit=final_index + 2)
            chunks = [
                item for item in memories if item.key.startswith(self._CHUNK_PREFIX)
            ]
            # Order by chunk index explicitly instead of trusting iteration order.
            chunks.sort(key=lambda item: int(item.key[len(self._CHUNK_PREFIX):]))
            messages: List[str] = [str(item.value) for item in chunks if item.value]

            if not messages:
                self.logger.warning(f"No messages found for thread {thread_id}")
                return False
            if not self.checkpoint_saver:
                self.logger.warning("Checkpoint saver is disabled")
                # These chunks can never be persisted; free them.
                self._clear_buffer(thread_id, store_namespace, memories)
                return False

            if self.mongo_db is not None:
                success = self._persist_to_mongodb(thread_id, messages)
            elif self.postgres_conn is not None:
                success = self._persist_to_postgresql(thread_id, messages)
            else:
                self.logger.warning("No database connection available")
                # No code path reconnects later, so drop the buffer.
                self._clear_buffer(thread_id, store_namespace, memories)
                return False

            if success:
                self._clear_buffer(thread_id, store_namespace, memories)
            return success
        except Exception as e:
            self.logger.error(
                f"Error persisting conversation for thread {thread_id}: {e}"
            )
            return False

    def _clear_buffer(self, thread_id: str, store_namespace: Tuple[str, str], items) -> None:
        """Delete the given buffered items (chunks and cursor); errors are logged, not raised."""
        try:
            for item in items:
                self.store.delete(store_namespace, item.key)
        except Exception as e:
            self.logger.error(
                f"Error cleaning up memory store for thread {thread_id}: {e}"
            )

    def _persist_to_mongodb(self, thread_id: str, messages: List[str]) -> bool:
        """
        Append messages to the thread's MongoDB document, creating it if absent.

        A single upsert replaces the previous find-then-insert. NOTE: MongoDB
        only fully prevents duplicate documents from concurrent first writes
        when a unique index exists on thread_id.
        """
        try:
            collection = self.mongo_db.chat_streams
            result = collection.update_one(
                {"thread_id": thread_id},
                {
                    "$push": {"messages": {"$each": messages}},
                    # Timezone-aware: PyMongo stores naive datetimes as if they were UTC.
                    "$set": {"ts": datetime.now(timezone.utc)},
                    "$setOnInsert": {"id": uuid.uuid4().hex},
                },
                upsert=True,
            )
            if result.upserted_id is not None:
                self.logger.info(f"Created new conversation: {result.upserted_id}")
                return True
            self.logger.info(
                f"Updated conversation for thread {thread_id}: "
                f"{result.modified_count} documents modified"
            )
            return result.modified_count > 0
        except Exception as e:
            self.logger.error(f"Error persisting to MongoDB: {e}")
            return False

    def _persist_to_postgresql(self, thread_id: str, messages: List[str]) -> bool:
        """
        Append messages to the thread's PostgreSQL row, inserting it if absent.

        Relies on the UNIQUE constraint on thread_id declared in
        _create_chat_streams_table to make the insert-or-append atomic.
        """
        try:
            with self.postgres_conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO chat_streams (id, thread_id, messages, ts)
                    VALUES (%s, %s, %s::jsonb, %s)
                    ON CONFLICT (thread_id) DO UPDATE
                    SET messages = chat_streams.messages || EXCLUDED.messages,
                        ts = EXCLUDED.ts
                    """,
                    (
                        uuid.uuid4(),
                        thread_id,
                        json.dumps(messages),
                        datetime.now(timezone.utc),
                    ),
                )
                affected_rows = cursor.rowcount
            self.postgres_conn.commit()
            self.logger.info(
                f"Persisted {len(messages)} messages for thread {thread_id}: "
                f"{affected_rows} rows affected"
            )
            return affected_rows > 0
        except Exception as e:
            self.logger.error(f"Error persisting to PostgreSQL: {e}")
            if self.postgres_conn:
                self.postgres_conn.rollback()
            return False

    def close(self) -> None:
        """Close database connections; safe to call more than once."""
        if self.mongo_client is not None:
            try:
                self.mongo_client.close()
                self.logger.info("MongoDB connection closed")
            except Exception as e:
                self.logger.error(f"Error closing MongoDB connection: {e}")
            finally:
                # Later persists then report "no connection" instead of using a closed client.
                self.mongo_client = None
                self.mongo_db = None
        if self.postgres_conn is not None:
            try:
                self.postgres_conn.close()
                self.logger.info("PostgreSQL connection closed")
            except Exception as e:
                self.logger.error(f"Error closing PostgreSQL connection: {e}")
            finally:
                self.postgres_conn = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close connections."""
        self.close()
# Global instance for backward compatibility.
# TODO: Consider using dependency injection instead of global instance
# Built at import time from environment configuration; when
# LANGGRAPH_CHECKPOINT_SAVER is true, constructing it attempts to open the
# configured database connection immediately.
_default_manager = ChatStreamManager(
    db_uri=get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "mongodb://localhost:27017"),
    checkpoint_saver=get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False),
)
def chat_stream_message(thread_id: str, message: str, finish_reason: str) -> bool:
    """
    Legacy module-level entry point that forwards to the shared manager.

    The LANGGRAPH_CHECKPOINT_SAVER env var is re-read on every call; when it is
    not enabled the chunk is ignored and nothing is buffered.

    Args:
        thread_id: Unique identifier for the conversation thread
        message: The message content to store
        finish_reason: Reason for message completion

    Returns:
        bool: False when checkpointing is disabled, otherwise the result of
        ChatStreamManager.process_stream_message on the default manager.
    """
    enabled = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
    if not enabled:
        return False
    return _default_manager.process_stream_message(thread_id, message, finish_reason)