diff --git a/backend/src/gateway/app.py b/backend/src/gateway/app.py index 796df12..443587f 100644 --- a/backend/src/gateway/app.py +++ b/backend/src/gateway/app.py @@ -19,6 +19,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: logger.info(f"Proxying to LangGraph server at {config.langgraph_url}") yield logger.info("Shutting down API Gateway") + # Close the shared HTTP client + await proxy.close_http_client() def create_app() -> FastAPI: diff --git a/backend/src/gateway/routers/proxy.py b/backend/src/gateway/routers/proxy.py index f726cde..f307a22 100644 --- a/backend/src/gateway/routers/proxy.py +++ b/backend/src/gateway/routers/proxy.py @@ -11,6 +11,31 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=["proxy"]) +# Shared httpx client for all proxy requests +# This avoids creating/closing clients during streaming responses +_http_client: httpx.AsyncClient | None = None + + +def get_http_client() -> httpx.AsyncClient: + """Get or create the shared HTTP client. + + Returns: + The shared httpx AsyncClient instance. + """ + global _http_client + if _http_client is None: + _http_client = httpx.AsyncClient() + return _http_client + + +async def close_http_client() -> None: + """Close the shared HTTP client if it exists.""" + global _http_client + if _http_client is not None: + await _http_client.aclose() + _http_client = None + + # Hop-by-hop headers that should not be forwarded EXCLUDED_HEADERS = { "host", @@ -76,57 +101,58 @@ async def proxy_request(request: Request, path: str) -> Response | StreamingResp if request.method not in ("GET", "HEAD"): body = await request.body() - async with httpx.AsyncClient() as client: - try: - # First, make a non-streaming request to check content type - response = await client.request( - method=request.method, - url=target_url, - headers=headers, - content=body, - timeout=config.proxy_timeout, + client = get_http_client() + + try: + # First, make a non-streaming request to check content type + response = await client.request( + method=request.method, + url=target_url, + headers=headers, + content=body, + timeout=config.proxy_timeout, + ) + + content_type = response.headers.get("content-type", "") + + # Check if response is SSE (Server-Sent Events) + if "text/event-stream" in content_type: + # For SSE, we need to re-request with streaming + return StreamingResponse( + stream_response(client, request.method, target_url, headers, body, config.stream_timeout), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, ) - content_type = response.headers.get("content-type", "") + # Prepare response headers + response_headers = dict(response.headers) + for header in ["transfer-encoding", "connection", "keep-alive"]: + response_headers.pop(header, None) - # Check if response is SSE (Server-Sent Events) - if "text/event-stream" in content_type: - # For SSE, we need to re-request with streaming - return StreamingResponse( - stream_response(client, request.method, target_url, headers, body, config.stream_timeout), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) + return Response( + content=response.content, + status_code=response.status_code, + headers=response_headers, + ) - # Prepare response headers - response_headers = dict(response.headers) - for header in ["transfer-encoding", "connection", "keep-alive"]: - response_headers.pop(header, None) - - return Response( - content=response.content, - status_code=response.status_code, - headers=response_headers, - ) - - except httpx.TimeoutException: - logger.error(f"Proxy request to {target_url} timed out") - return Response( - content='{"error": "Proxy request timed out"}', - status_code=504, - media_type="application/json", - ) - except httpx.RequestError as e: - logger.error(f"Proxy request to {target_url} failed: {e}") - return Response( - content='{"error": "Proxy request failed"}', - status_code=502, - media_type="application/json", - ) + except httpx.TimeoutException: + logger.error(f"Proxy request to {target_url} timed out") + return Response( + content='{"error": "Proxy request timed out"}', + status_code=504, + media_type="application/json", + ) + except httpx.RequestError as e: + logger.error(f"Proxy request to {target_url} failed: {e}") + return Response( + content='{"error": "Proxy request failed"}', + status_code=502, + media_type="application/json", + ) @router.api_route(