feat:Database connections use connection pools (#757)

* feat: Implement DeerFlow API server with chat streaming, Langgraph orchestration, and various content generation capabilities.

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
YikB
2025-12-23 20:35:08 +08:00
committed by GitHub
parent 1f403a9f79
commit 83e9d7c9e5

View File

@@ -20,7 +20,7 @@ from langgraph.types import Command
from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit
from src.config.loader import get_bool_env, get_str_env
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
@@ -73,10 +73,135 @@ if os.name == "nt":
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"
# Global connection pools (initialized at startup if configured)
_pg_pool: Optional[AsyncConnectionPool] = None
_pg_checkpointer: Optional[AsyncPostgresSaver] = None
# Global MongoDB connection (initialized at startup if configured)
_mongo_client: Optional[Any] = None
_mongo_checkpointer: Optional[AsyncMongoDBSaver] = None
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app):
"""
Application lifecycle manager
- Startup: Register asyncio exception handler and initialize global connection pools
- Shutdown: Clean up global connection pools
"""
global _pg_pool, _pg_checkpointer, _mongo_client, _mongo_checkpointer
# ========== STARTUP ==========
try:
asyncio.get_running_loop()
except RuntimeError as e:
logger.warning(f"Could not register asyncio exception handler: {e}")
# Initialize global connection pool based on configuration
checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
if not checkpoint_saver or not checkpoint_url:
logger.info("Checkpoint saver not configured, skipping connection pool initialization")
else:
# Initialize PostgreSQL connection pool
if checkpoint_url.startswith("postgresql://"):
pool_min_size = get_int_env("PG_POOL_MIN_SIZE", 5)
pool_max_size = get_int_env("PG_POOL_MAX_SIZE", 20)
pool_timeout = get_int_env("PG_POOL_TIMEOUT", 60)
connection_kwargs = {
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
}
logger.info(
f"Initializing global PostgreSQL connection pool: "
f"min_size={pool_min_size}, max_size={pool_max_size}, timeout={pool_timeout}s"
)
try:
_pg_pool = AsyncConnectionPool(
checkpoint_url,
kwargs=connection_kwargs,
min_size=pool_min_size,
max_size=pool_max_size,
timeout=pool_timeout,
)
await _pg_pool.open()
_pg_checkpointer = AsyncPostgresSaver(_pg_pool)
await _pg_checkpointer.setup()
logger.info("Global PostgreSQL connection pool initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
_pg_pool = None
_pg_checkpointer = None
raise RuntimeError(
"Checkpoint persistence is explicitly configured with PostgreSQL, "
"but initialization failed. Application will not start."
) from e
# Initialize MongoDB connection pool
elif checkpoint_url.startswith("mongodb://"):
try:
from motor.motor_asyncio import AsyncIOMotorClient
# MongoDB connection pool settings
mongo_max_pool_size = get_int_env("MONGO_MAX_POOL_SIZE", 20)
mongo_min_pool_size = get_int_env("MONGO_MIN_POOL_SIZE", 5)
logger.info(
f"Initializing global MongoDB connection pool: "
f"min_pool_size={mongo_min_pool_size}, max_pool_size={mongo_max_pool_size}"
)
_mongo_client = AsyncIOMotorClient(
checkpoint_url,
maxPoolSize=mongo_max_pool_size,
minPoolSize=mongo_min_pool_size,
)
# Create the MongoDB checkpointer using the global client
_mongo_checkpointer = AsyncMongoDBSaver(_mongo_client)
await _mongo_checkpointer.setup()
logger.info("Global MongoDB connection pool initialized successfully")
except ImportError:
logger.error("motor package not installed. Please install it with: pip install motor")
raise RuntimeError("MongoDB checkpoint persistence is configured but the 'motor' package is not installed. Aborting startup.")
except Exception as e:
logger.error(f"Failed to initialize MongoDB connection pool: {e}")
raise RuntimeError(f"MongoDB checkpoint persistence is configured but could not be initialized: {e}")
# ========== YIELD - Application runs here ==========
yield
# ========== SHUTDOWN ==========
# Close PostgreSQL connection pool
if _pg_pool:
logger.info("Closing global PostgreSQL connection pool")
await _pg_pool.close()
logger.info("Global PostgreSQL connection pool closed")
# Close MongoDB connection
if _mongo_client:
logger.info("Closing global MongoDB connection")
_mongo_client.close()
logger.info("Global MongoDB connection closed")
app = FastAPI(
title="DeerFlow API",
description="API for Deer",
version="0.1.0",
lifespan=lifespan,
)
# Add CORS middleware
@@ -612,23 +737,33 @@ async def _astream_workflow_generator(
f"url_configured={bool(checkpoint_url)}"
)
# Handle checkpointer if configured
connection_kwargs = {
"autocommit": True,
"row_factory": "dict_row",
"prepare_threshold": 0,
}
# Handle checkpointer if configured - prefer global connection pools
if checkpoint_saver and checkpoint_url != "":
if checkpoint_url.startswith("postgresql://"):
logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
# Try to use global PostgreSQL checkpointer first
if checkpoint_url.startswith("postgresql://") and _pg_checkpointer:
logger.info(f"[{safe_thread_id}] Using global PostgreSQL connection pool")
graph.checkpointer = _pg_checkpointer
graph.store = in_memory_store
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id
):
yield event
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
# Fallback to per-request PostgreSQL connection if global pool not available
elif checkpoint_url.startswith("postgresql://"):
logger.info(f"[{safe_thread_id}] Global pool unavailable, creating per-request PostgreSQL connection")
connection_kwargs = {
"autocommit": True,
"row_factory": "dict_row",
"prepare_threshold": 0,
}
async with AsyncConnectionPool(
checkpoint_url, kwargs=connection_kwargs
) as conn:
logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
checkpointer = AsyncPostgresSaver(conn)
await checkpointer.setup()
logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
graph.checkpointer = checkpointer
graph.store = in_memory_store
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
@@ -638,13 +773,24 @@ async def _astream_workflow_generator(
yield event
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
if checkpoint_url.startswith("mongodb://"):
logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
# Try to use global MongoDB checkpointer first
elif checkpoint_url.startswith("mongodb://") and _mongo_checkpointer:
logger.info(f"[{safe_thread_id}] Using global MongoDB connection pool")
graph.checkpointer = _mongo_checkpointer
graph.store = in_memory_store
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id
):
yield event
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
# Fallback to per-request MongoDB connection if global pool not available
elif checkpoint_url.startswith("mongodb://"):
logger.info(f"[{safe_thread_id}] Global pool unavailable, creating per-request MongoDB connection")
async with AsyncMongoDBSaver.from_conn_string(
checkpoint_url
) as checkpointer:
logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
graph.checkpointer = checkpointer
graph.store = in_memory_store
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
@@ -655,7 +801,7 @@ async def _astream_workflow_generator(
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
else:
logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
# Use graph without MongoDB checkpointer
# Use graph without checkpointer
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
async for event in _stream_graph_events(
graph, workflow_input, workflow_config, thread_id