mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-05 07:02:13 +08:00
* feat: Implement DeerFlow API server with chat streaming, Langgraph orchestration, and various content generation capabilities. * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * - Use MongoDB `$push` with `$each` to append new messages to existing threads - Use PostgreSQL jsonb concatenation operator to merge messages instead of overwriting - Update comments to reflect append behavior in both database implementations --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
389 lines
15 KiB
Python
389 lines
15 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import List, Optional, Tuple
|
|
|
|
import psycopg
|
|
from langgraph.store.memory import InMemoryStore
|
|
from psycopg.rows import dict_row
|
|
from pymongo import MongoClient
|
|
|
|
from src.config.loader import get_bool_env, get_str_env
|
|
|
|
|
|
class ChatStreamManager:
    """
    Buffer streamed chat chunks in memory and persist finished conversations.

    Chunks are kept in an ``InMemoryStore`` under the namespace
    ``("messages", thread_id)`` while a response is streaming. When a chunk
    arrives with a finish reason of ``"stop"`` or ``"interrupt"``, the buffered
    chunks are appended to the thread's record in MongoDB or PostgreSQL
    (whichever connection was initialized) and, on success, removed from memory.

    Attributes:
        store (InMemoryStore): In-memory storage for temporary message chunks
        checkpoint_saver (bool): Whether persistence to a database is enabled
        db_uri (str): Connection URI used to pick and open the backend
        mongo_client (MongoClient): MongoDB client connection, or None
        mongo_db (Database): MongoDB database instance, or None
        postgres_conn (psycopg.Connection): PostgreSQL connection, or None
        logger (logging.Logger): Logger instance for this class
    """

    # Fallback used when the caller does not pass ``db_uri``.
    _DB_URI_ENV_VAR = "LANGGRAPH_CHECKPOINT_DB_URL"
    _DEFAULT_DB_URI = "mongodb://localhost:27017"

    # URI prefixes that select each backend.
    _MONGODB_SCHEMES = ("mongodb://", "mongodb+srv://")
    _POSTGRES_SCHEMES = ("postgresql://", "postgres://")

    def __init__(
        self, checkpoint_saver: bool = False, db_uri: Optional[str] = None
    ) -> None:
        """
        Initialize the ChatStreamManager with database connections.

        Args:
            checkpoint_saver: When False, no database connection is opened and
                finished conversations are not persisted.
            db_uri: Database connection URI. Supports MongoDB (mongodb://,
                mongodb+srv://) and PostgreSQL (postgresql://, postgres://).
                If None, uses the LANGGRAPH_CHECKPOINT_DB_URL env var or
                defaults to mongodb://localhost:27017.
        """
        self.logger = logging.getLogger(__name__)
        self.store = InMemoryStore()
        self.checkpoint_saver = checkpoint_saver

        # Use provided URI or fall back to environment variable or default.
        # Without this, a None URI would crash on ``startswith`` below.
        if db_uri is None:
            db_uri = get_str_env(self._DB_URI_ENV_VAR, self._DEFAULT_DB_URI)
        self.db_uri = db_uri

        # Initialize database connections
        self.mongo_client = None
        self.mongo_db = None
        self.postgres_conn = None

        if not self.checkpoint_saver:
            self.logger.warning("Checkpoint saver is disabled")
            return

        uri = self.db_uri or ""
        if uri.startswith(self._MONGODB_SCHEMES):
            self._init_mongodb()
        elif uri.startswith(self._POSTGRES_SCHEMES):
            self._init_postgresql()
        else:
            self.logger.warning(
                f"Unsupported database URI scheme: {self.db_uri}. "
                "Supported schemes: mongodb://, mongodb+srv://, "
                "postgresql://, postgres://"
            )

    @staticmethod
    def _current_timestamp() -> datetime:
        """
        Return the current time as a timezone-aware datetime.

        Naive datetimes are ambiguous to the database drivers (PyMongo treats
        them as UTC, PostgreSQL interprets them in the session time zone), so
        the local UTC offset is attached explicitly.
        """
        return datetime.now().astimezone()

    def _init_mongodb(self) -> None:
        """Initialize MongoDB connection; failures are logged, not raised."""
        try:
            self.mongo_client = MongoClient(self.db_uri)
            self.mongo_db = self.mongo_client.checkpointing_db
            # Test connection
            self.mongo_client.admin.command("ping")
            self.logger.info("Successfully connected to MongoDB")
        except Exception as e:
            self.logger.error(f"Failed to connect to MongoDB: {e}")

    def _init_postgresql(self) -> None:
        """Initialize PostgreSQL connection and create table if needed."""
        try:
            self.postgres_conn = psycopg.connect(self.db_uri, row_factory=dict_row)
            self.logger.info("Successfully connected to PostgreSQL")
            self._create_chat_streams_table()
        except Exception as e:
            self.logger.error(f"Failed to connect to PostgreSQL: {e}")

    def _create_chat_streams_table(self) -> None:
        """Create the chat_streams table and its indexes if they don't exist."""
        try:
            with self.postgres_conn.cursor() as cursor:
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS chat_streams (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    thread_id VARCHAR(255) NOT NULL UNIQUE,
                    messages JSONB NOT NULL,
                    ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
                );

                CREATE INDEX IF NOT EXISTS idx_chat_streams_thread_id ON chat_streams(thread_id);
                CREATE INDEX IF NOT EXISTS idx_chat_streams_ts ON chat_streams(ts);
                """
                cursor.execute(create_table_sql)
                self.postgres_conn.commit()
                self.logger.info("Chat streams table created/verified successfully")
        except Exception as e:
            self.logger.error(f"Failed to create chat_streams table: {e}")
            if self.postgres_conn:
                self.postgres_conn.rollback()

    def process_stream_message(
        self, thread_id: str, message: str, finish_reason: str
    ) -> bool:
        """
        Process and store a chat stream message chunk.

        Each chunk is buffered in the in-memory store under a running index.
        When ``finish_reason`` signals the end of the stream, all buffered
        chunks for the thread are persisted to the configured database
        (MongoDB or PostgreSQL).

        Args:
            thread_id: Unique identifier for the conversation thread
            message: The message content or chunk to store
            finish_reason: Reason for message completion ("stop", "interrupt",
                or any other value for a partial chunk)

        Returns:
            bool: True if the chunk was buffered (and, for a final chunk,
            persisted) successfully, False otherwise
        """
        if not thread_id or not isinstance(thread_id, str):
            self.logger.warning("Invalid thread_id provided")
            return False

        if not message:
            self.logger.warning("Empty message provided")
            return False

        try:
            # Create namespace for this thread's messages
            store_namespace: Tuple[str, str] = ("messages", thread_id)

            # Get or initialize message cursor for tracking chunks
            cursor = self.store.get(store_namespace, "cursor")
            current_index = 0

            if cursor is None:
                # Initialize cursor for new conversation
                self.store.put(store_namespace, "cursor", {"index": 0})
            else:
                # Increment index for next chunk
                current_index = int(cursor.value.get("index", 0)) + 1
                self.store.put(store_namespace, "cursor", {"index": current_index})

            # Store the current message chunk
            self.store.put(store_namespace, f"chunk_{current_index}", message)

            # Check if conversation is complete and should be persisted
            if finish_reason in ("stop", "interrupt"):
                return self._persist_complete_conversation(
                    thread_id, store_namespace, current_index
                )

            return True

        except Exception as e:
            self.logger.error(
                f"Error processing stream message for thread {thread_id}: {e}"
            )
            return False

    def _persist_complete_conversation(
        self, thread_id: str, store_namespace: Tuple[str, str], final_index: int
    ) -> bool:
        """
        Persist completed conversation to database (MongoDB or PostgreSQL).

        Retrieves all message chunks from the memory store, saves them to the
        configured database, and deletes the buffered items once the write
        succeeded.

        Args:
            thread_id: Unique identifier for the conversation thread
            store_namespace: Namespace tuple for accessing stored messages
            final_index: The final chunk index for this conversation

        Returns:
            bool: True if persistence was successful, False otherwise
        """
        try:
            # Chunks are indexed 0..final_index, plus one "cursor" entry,
            # hence final_index + 2 items at most.
            memories = self.store.search(store_namespace, limit=final_index + 2)

            # Extract message content, filtering out cursor metadata
            # (the cursor is the only dict-valued entry).
            messages: List[str] = []
            for item in memories:
                value = item.dict().get("value", "")
                if value and not isinstance(value, dict):
                    messages.append(str(value))

            if not messages:
                self.logger.warning(f"No messages found for thread {thread_id}")
                return False

            if not self.checkpoint_saver:
                self.logger.warning("Checkpoint saver is disabled")
                return False

            # Choose persistence method based on available connection
            if self.mongo_db is not None:
                success = self._persist_to_mongodb(thread_id, messages)
            elif self.postgres_conn is not None:
                success = self._persist_to_postgresql(thread_id, messages)
            else:
                self.logger.warning("No database connection available")
                return False

            if success:
                # Cleanup is best-effort: a failure here must not turn a
                # successful write into a reported failure.
                try:
                    for item in memories:
                        self.store.delete(store_namespace, item.key)
                except Exception as e:
                    self.logger.error(
                        f"Error cleaning up memory store for thread {thread_id}: {e}"
                    )

            return success

        except Exception as e:
            self.logger.error(
                f"Error persisting conversation for thread {thread_id}: {e}"
            )
            return False

    def _persist_to_mongodb(self, thread_id: str, messages: List[str]) -> bool:
        """
        Persist conversation to MongoDB.

        Appends ``messages`` to the thread's existing document, or inserts a
        new document when the thread has none.

        Returns:
            bool: True if a document was modified or inserted, False otherwise
        """
        try:
            # Get MongoDB collection for chat streams
            collection = self.mongo_db.chat_streams

            # Check if conversation already exists in database
            existing_document = collection.find_one({"thread_id": thread_id})

            current_timestamp = self._current_timestamp()

            if existing_document:
                # Append new messages to existing conversation
                update_result = collection.update_one(
                    {"thread_id": thread_id},
                    {
                        "$push": {"messages": {"$each": messages}},
                        "$set": {"ts": current_timestamp},
                    },
                )
                self.logger.info(
                    f"Updated conversation for thread {thread_id}: "
                    f"{update_result.modified_count} documents modified"
                )
                return update_result.modified_count > 0

            # Create new conversation document
            new_document = {
                "thread_id": thread_id,
                "messages": messages,
                "ts": current_timestamp,
                "id": uuid.uuid4().hex,
            }
            insert_result = collection.insert_one(new_document)
            self.logger.info(
                f"Created new conversation: {insert_result.inserted_id}"
            )
            return insert_result.inserted_id is not None

        except Exception as e:
            self.logger.error(f"Error persisting to MongoDB: {e}")
            return False

    def _persist_to_postgresql(self, thread_id: str, messages: List[str]) -> bool:
        """
        Persist conversation to PostgreSQL.

        Appends ``messages`` to the thread's existing row using JSONB
        concatenation, or inserts a new row when the thread has none. The
        transaction is rolled back on any error.

        Returns:
            bool: True if a row was updated or inserted, False otherwise
        """
        try:
            with self.postgres_conn.cursor() as cursor:
                # Check if conversation already exists
                cursor.execute(
                    "SELECT id FROM chat_streams WHERE thread_id = %s", (thread_id,)
                )
                existing_record = cursor.fetchone()

                current_timestamp = self._current_timestamp()
                messages_json = json.dumps(messages)

                if existing_record:
                    # Append new messages to existing conversation
                    cursor.execute(
                        """
                        UPDATE chat_streams
                        SET messages = messages || %s::jsonb, ts = %s
                        WHERE thread_id = %s
                        """,
                        (messages_json, current_timestamp, thread_id),
                    )
                    affected_rows = cursor.rowcount
                    self.postgres_conn.commit()

                    self.logger.info(
                        f"Updated conversation for thread {thread_id}: "
                        f"{affected_rows} rows modified"
                    )
                    return affected_rows > 0

                # Create new conversation record
                conversation_id = uuid.uuid4()
                cursor.execute(
                    """
                    INSERT INTO chat_streams (id, thread_id, messages, ts)
                    VALUES (%s, %s, %s, %s)
                    """,
                    (conversation_id, thread_id, messages_json, current_timestamp),
                )
                affected_rows = cursor.rowcount
                self.postgres_conn.commit()

                self.logger.info(
                    f"Created new conversation with ID: {conversation_id}"
                )
                return affected_rows > 0

        except Exception as e:
            self.logger.error(f"Error persisting to PostgreSQL: {e}")
            if self.postgres_conn:
                self.postgres_conn.rollback()
            return False

    def close(self) -> None:
        """Close database connections; errors are logged, not raised."""
        try:
            if self.mongo_client is not None:
                self.mongo_client.close()
                self.logger.info("MongoDB connection closed")
        except Exception as e:
            self.logger.error(f"Error closing MongoDB connection: {e}")

        try:
            if self.postgres_conn is not None:
                self.postgres_conn.close()
                self.logger.info("PostgreSQL connection closed")
        except Exception as e:
            self.logger.error(f"Error closing PostgreSQL connection: {e}")

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close connections."""
        self.close()
|
|
|
|
|
|
# Shared module-level manager used by the legacy ``chat_stream_message``
# wrapper. It is built at import time from environment configuration.
# TODO: Consider using dependency injection instead of global instance
_default_manager = ChatStreamManager(
    db_uri=get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "mongodb://localhost:27017"),
    checkpoint_saver=get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False),
)
|
|
|
|
|
|
def chat_stream_message(thread_id: str, message: str, finish_reason: str) -> bool:
    """
    Forward a stream chunk to the module-level manager (legacy entry point).

    The LANGGRAPH_CHECKPOINT_SAVER flag is read on every call; when it is
    off, the chunk is dropped and nothing is buffered.

    Args:
        thread_id: Unique identifier for the conversation thread
        message: The message content to store
        finish_reason: Reason for message completion

    Returns:
        bool: Result of ``process_stream_message`` when the checkpoint saver
        is enabled, otherwise False
    """
    # Guard clause: persistence disabled means there is nothing to do.
    if not get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False):
        return False

    return _default_manager.process_stream_message(thread_id, message, finish_reason)
|