fix: use shared httpx client to prevent premature closure in SSE streaming

The proxy was creating a temporary httpx.AsyncClient within an async context manager.
When returning StreamingResponse for SSE endpoints, the client was being closed before
the streaming generator could use it, causing "client has been closed" errors.

This change introduces a shared httpx.AsyncClient that persists for the application
lifecycle, properly cleaned up during shutdown. This also improves performance by
reusing TCP connections across requests.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
hetaoBackend
2026-01-19 16:52:30 +08:00
parent 3a4149c437
commit ffb9ed3198
2 changed files with 75 additions and 47 deletions

View File

@@ -19,6 +19,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info(f"Proxying to LangGraph server at {config.langgraph_url}") logger.info(f"Proxying to LangGraph server at {config.langgraph_url}")
yield yield
logger.info("Shutting down API Gateway") logger.info("Shutting down API Gateway")
# Close the shared HTTP client
await proxy.close_http_client()
def create_app() -> FastAPI: def create_app() -> FastAPI:

View File

@@ -11,6 +11,31 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["proxy"]) router = APIRouter(tags=["proxy"])
# Shared httpx client for all proxy requests.
# This avoids creating/closing clients during streaming responses.
# Starts as None: get_http_client() creates the instance on first use, and
# close_http_client() closes it and resets this name back to None.
_http_client: httpx.AsyncClient | None = None
def get_http_client() -> httpx.AsyncClient:
    """Return the module-wide HTTP client, building it on first use.

    The instance is stored in the module-level ``_http_client`` variable, so
    repeated calls hand back the same object until that variable is reset
    to None.

    Returns:
        The shared httpx AsyncClient instance.
    """
    global _http_client
    # Fast path: a client has already been created, reuse it.
    if _http_client is not None:
        return _http_client
    # First call (or first call after a reset): create and remember it.
    _http_client = httpx.AsyncClient()
    return _http_client
async def close_http_client() -> None:
    """Close the shared HTTP client if it exists.

    The module-level reference is cleared *before* the close is awaited.
    That way get_http_client() never returns a client that is in the middle
    of closing, and the reference is still reset when aclose() raises or the
    await is cancelled. Calling this when no client exists is a no-op.

    Raises:
        Whatever the client's aclose() raises; the shared reference has
        already been reset to None by then.
    """
    global _http_client
    # Detach first, close second: the global must not keep pointing at a
    # client that is closing (or that failed to close).
    client, _http_client = _http_client, None
    if client is not None:
        await client.aclose()
# Hop-by-hop headers that should not be forwarded # Hop-by-hop headers that should not be forwarded
EXCLUDED_HEADERS = { EXCLUDED_HEADERS = {
"host", "host",
@@ -76,57 +101,58 @@ async def proxy_request(request: Request, path: str) -> Response | StreamingResp
if request.method not in ("GET", "HEAD"): if request.method not in ("GET", "HEAD"):
body = await request.body() body = await request.body()
async with httpx.AsyncClient() as client: client = get_http_client()
try:
# First, make a non-streaming request to check content type try:
response = await client.request( # First, make a non-streaming request to check content type
method=request.method, response = await client.request(
url=target_url, method=request.method,
headers=headers, url=target_url,
content=body, headers=headers,
timeout=config.proxy_timeout, content=body,
timeout=config.proxy_timeout,
)
content_type = response.headers.get("content-type", "")
# Check if response is SSE (Server-Sent Events)
if "text/event-stream" in content_type:
# For SSE, we need to re-request with streaming
return StreamingResponse(
stream_response(client, request.method, target_url, headers, body, config.stream_timeout),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
) )
content_type = response.headers.get("content-type", "") # Prepare response headers
response_headers = dict(response.headers)
for header in ["transfer-encoding", "connection", "keep-alive"]:
response_headers.pop(header, None)
# Check if response is SSE (Server-Sent Events) return Response(
if "text/event-stream" in content_type: content=response.content,
# For SSE, we need to re-request with streaming status_code=response.status_code,
return StreamingResponse( headers=response_headers,
stream_response(client, request.method, target_url, headers, body, config.stream_timeout), )
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# Prepare response headers except httpx.TimeoutException:
response_headers = dict(response.headers) logger.error(f"Proxy request to {target_url} timed out")
for header in ["transfer-encoding", "connection", "keep-alive"]: return Response(
response_headers.pop(header, None) content='{"error": "Proxy request timed out"}',
status_code=504,
return Response( media_type="application/json",
content=response.content, )
status_code=response.status_code, except httpx.RequestError as e:
headers=response_headers, logger.error(f"Proxy request to {target_url} failed: {e}")
) return Response(
content='{"error": "Proxy request failed"}',
except httpx.TimeoutException: status_code=502,
logger.error(f"Proxy request to {target_url} timed out") media_type="application/json",
return Response( )
content='{"error": "Proxy request timed out"}',
status_code=504,
media_type="application/json",
)
except httpx.RequestError as e:
logger.error(f"Proxy request to {target_url} failed: {e}")
return Response(
content='{"error": "Proxy request failed"}',
status_code=502,
media_type="application/json",
)
@router.api_route( @router.api_route(